From 1e67c90e2caceeff82d09793d1ef5fa0300d219b Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Mon, 9 Jan 2017 12:04:37 -0800
Subject: Initial open-source release of XLA: Accelerated Linear Algebra.

XLA is a compiler-based linear algebra execution engine that targets CPUs,
GPUs and custom accelerators.

XLA is still experimental; we are releasing it early to get the community
involved.

Change: 143990941
---
 configure | 20 + tensorflow/BUILD | 20 + tensorflow/compiler/aot/BUILD | 218 + tensorflow/compiler/aot/benchmark.cc | 138 + tensorflow/compiler/aot/benchmark.h | 70 + tensorflow/compiler/aot/benchmark_main.template | 51 + tensorflow/compiler/aot/benchmark_test.cc | 46 + tensorflow/compiler/aot/codegen.cc | 579 ++ tensorflow/compiler/aot/codegen.h | 53 + tensorflow/compiler/aot/codegen_test.cc | 137 + tensorflow/compiler/aot/codegen_test_h.golden | 268 + tensorflow/compiler/aot/compile.cc | 416 ++ tensorflow/compiler/aot/compile.h | 92 + tensorflow/compiler/aot/flags.cc | 72 + tensorflow/compiler/aot/flags.h | 48 + tensorflow/compiler/aot/runtime.cc | 98 + tensorflow/compiler/aot/runtime.h | 58 + tensorflow/compiler/aot/runtime_test.cc | 125 + tensorflow/compiler/aot/test.cc | 94 + .../compiler/aot/test_graph_tfadd.config.pbtxt | 16 + tensorflow/compiler/aot/test_graph_tfadd.pbtxt | 63 + tensorflow/compiler/aot/tests/BUILD | 146 + tensorflow/compiler/aot/tests/make_test_graphs.py | 119 + .../aot/tests/test_graph_tfadd.config.pbtxt | 16 + .../tests/test_graph_tfadd_with_ckpt.config.pbtxt | 10 + .../aot/tests/test_graph_tfgather.config.pbtxt | 16 + .../aot/tests/test_graph_tfmatmul.config.pbtxt | 18 + .../tests/test_graph_tfmatmulandadd.config.pbtxt | 25 + tensorflow/compiler/aot/tests/tfcompile_test.cc | 381 ++ tensorflow/compiler/aot/tfcompile.bzl | 285 + tensorflow/compiler/aot/tfcompile.proto | 43 + tensorflow/compiler/aot/tfcompile_main.cc | 142 + tensorflow/compiler/aot/tfcompile_util.cc | 119 + tensorflow/compiler/aot/tfcompile_util.h | 36 + tensorflow/compiler/aot/tfcompile_util_test.cc | 185 + tensorflow/compiler/jit/BUILD | 282 + .../compiler/jit/build_xla_launch_ops_pass.cc | 215 + .../compiler/jit/build_xla_launch_ops_pass.h | 31 + tensorflow/compiler/jit/defs.cc | 22 + tensorflow/compiler/jit/defs.h | 29 + .../compiler/jit/encapsulate_subgraphs_pass.cc | 660 +++ .../compiler/jit/encapsulate_subgraphs_pass.h | 86 + .../jit/encapsulate_subgraphs_pass_test.cc | 397 ++ tensorflow/compiler/jit/graph_to_functiondef.cc | 274 + tensorflow/compiler/jit/graph_to_functiondef.h | 33 + .../compiler/jit/graph_to_functiondef_test.cc | 87 + tensorflow/compiler/jit/graphcycles/BUILD | 41 + tensorflow/compiler/jit/graphcycles/graphcycles.cc | 391 ++ tensorflow/compiler/jit/graphcycles/graphcycles.h | 128 + .../compiler/jit/graphcycles/graphcycles_test.cc | 515 ++ .../jit/jit_compilation_pass_registration.cc | 37 + tensorflow/compiler/jit/legacy_flags/BUILD | 67 + .../encapsulate_subgraphs_pass_flags.cc | 63 + .../encapsulate_subgraphs_pass_flags.h | 50 + .../mark_for_compilation_pass_flags.cc | 76 + .../legacy_flags/mark_for_compilation_pass_flags.h | 59 + .../jit/legacy_flags/parallel_check_op_flags.cc | 68 + .../jit/legacy_flags/parallel_check_op_flags.h | 52 + .../compiler/jit/mark_for_compilation_pass.cc | 534 ++ .../compiler/jit/mark_for_compilation_pass.h | 55 + .../compiler/jit/mark_for_compilation_pass_test.cc | 357 ++ tensorflow/compiler/jit/parallel_check_op.cc | 154 + tensorflow/compiler/jit/xla_compilation_cache.cc | 199 + tensorflow/compiler/jit/xla_compilation_cache.h | 112 + 
tensorflow/compiler/jit/xla_cpu_device.cc | 60 + tensorflow/compiler/jit/xla_device.cc | 219 + tensorflow/compiler/jit/xla_device.h | 120 + tensorflow/compiler/jit/xla_device_context.cc | 181 + tensorflow/compiler/jit/xla_device_context.h | 92 + tensorflow/compiler/jit/xla_device_launch_op.cc | 171 + tensorflow/compiler/jit/xla_device_launch_op.h | 50 + tensorflow/compiler/jit/xla_device_ops.cc | 36 + tensorflow/compiler/jit/xla_device_ops.h | 118 + tensorflow/compiler/jit/xla_gpu_device.cc | 65 + tensorflow/compiler/jit/xla_local_launch_op.cc | 342 ++ tensorflow/compiler/jit/xla_local_launch_op.h | 55 + tensorflow/compiler/tests/BUILD | 352 ++ tensorflow/compiler/tests/binary_ops_test.py | 749 +++ tensorflow/compiler/tests/build_defs.bzl | 78 + tensorflow/compiler/tests/clustering_test.py | 102 + tensorflow/compiler/tests/concat_ops_test.py | 374 ++ tensorflow/compiler/tests/conv2d_test.py | 526 ++ .../compiler/tests/depthwise_conv2d_test_kernel.cc | 30 + tensorflow/compiler/tests/dynamic_stitch_test.py | 86 + tensorflow/compiler/tests/function_test.py | 130 + tensorflow/compiler/tests/jit_test.py | 459 ++ tensorflow/compiler/tests/lrn_ops_test.py | 129 + tensorflow/compiler/tests/lstm.py | 158 + .../tests/lstm_layer_inference.config.pbtxt | 20 + .../compiler/tests/lstm_layer_inference.pbtxt | 5828 ++++++++++++++++++++ tensorflow/compiler/tests/lstm_test.py | 293 + tensorflow/compiler/tests/nary_ops_test.py | 209 + tensorflow/compiler/tests/nullary_ops_test.py | 61 + tensorflow/compiler/tests/pooling_ops_test.py | 511 ++ tensorflow/compiler/tests/randomized_tests.cc | 2097 +++++++ tensorflow/compiler/tests/reduce_ops_test.py | 125 + tensorflow/compiler/tests/ternary_ops_test.py | 110 + tensorflow/compiler/tests/unary_ops_test.py | 346 ++ tensorflow/compiler/tests/xla_device_test.py | 81 + tensorflow/compiler/tests/xla_test.py | 148 + tensorflow/compiler/tf2xla/BUILD | 193 + tensorflow/compiler/tf2xla/const_analysis.cc | 139 + tensorflow/compiler/tf2xla/const_analysis.h | 33 + tensorflow/compiler/tf2xla/const_analysis_test.cc | 83 + tensorflow/compiler/tf2xla/dump_graph.cc | 78 + tensorflow/compiler/tf2xla/dump_graph.h | 50 + tensorflow/compiler/tf2xla/dump_graph_flags.cc | 63 + tensorflow/compiler/tf2xla/dump_graph_flags.h | 48 + tensorflow/compiler/tf2xla/kernels/BUILD | 177 + .../compiler/tf2xla/kernels/aggregate_ops.cc | 47 + .../compiler/tf2xla/kernels/batch_matmul_op.cc | 141 + tensorflow/compiler/tf2xla/kernels/bcast_ops.cc | 87 + tensorflow/compiler/tf2xla/kernels/bias_ops.cc | 119 + tensorflow/compiler/tf2xla/kernels/binary_ops.cc | 158 + tensorflow/compiler/tf2xla/kernels/cast_op.cc | 71 + tensorflow/compiler/tf2xla/kernels/concat_op.cc | 210 + tensorflow/compiler/tf2xla/kernels/conv_ops.cc | 373 ++ tensorflow/compiler/tf2xla/kernels/cwise_ops.cc | 177 + tensorflow/compiler/tf2xla/kernels/cwise_ops.h | 109 + .../compiler/tf2xla/kernels/declaration_op.cc | 127 + .../compiler/tf2xla/kernels/depthwise_conv_ops.cc | 235 + tensorflow/compiler/tf2xla/kernels/diag_op.cc | 255 + .../compiler/tf2xla/kernels/dynamic_stitch_op.cc | 200 + tensorflow/compiler/tf2xla/kernels/fill_op.cc | 74 + tensorflow/compiler/tf2xla/kernels/function_ops.cc | 110 + tensorflow/compiler/tf2xla/kernels/gather_op.cc | 104 + .../tf2xla/kernels/gather_op_kernel_float_int32.cc | 69 + .../tf2xla/kernels/gather_op_kernel_float_int64.cc | 69 + tensorflow/compiler/tf2xla/kernels/identity_op.cc | 39 + tensorflow/compiler/tf2xla/kernels/index_ops.cc | 142 + .../kernels/index_ops_kernel_argmax_float_1d.cc | 49 + 
.../kernels/index_ops_kernel_argmax_float_2d.cc | 51 + tensorflow/compiler/tf2xla/kernels/l2loss_op.cc | 53 + tensorflow/compiler/tf2xla/kernels/lrn_ops.cc | 173 + tensorflow/compiler/tf2xla/kernels/matmul_op.cc | 88 + tensorflow/compiler/tf2xla/kernels/no_op.cc | 24 + tensorflow/compiler/tf2xla/kernels/pack_op.cc | 93 + tensorflow/compiler/tf2xla/kernels/pad_op.cc | 80 + tensorflow/compiler/tf2xla/kernels/pooling_ops.cc | 374 ++ tensorflow/compiler/tf2xla/kernels/random_ops.cc | 116 + .../compiler/tf2xla/kernels/reduction_ops.cc | 157 + tensorflow/compiler/tf2xla/kernels/reduction_ops.h | 71 + .../tf2xla/kernels/reduction_ops_common.cc | 150 + tensorflow/compiler/tf2xla/kernels/relu_op.cc | 93 + tensorflow/compiler/tf2xla/kernels/reshape_op.cc | 101 + tensorflow/compiler/tf2xla/kernels/retval_op.cc | 79 + tensorflow/compiler/tf2xla/kernels/select_op.cc | 90 + tensorflow/compiler/tf2xla/kernels/sequence_ops.cc | 213 + tensorflow/compiler/tf2xla/kernels/shape_op.cc | 245 + tensorflow/compiler/tf2xla/kernels/slice_op.cc | 121 + tensorflow/compiler/tf2xla/kernels/softmax_op.cc | 152 + tensorflow/compiler/tf2xla/kernels/split_op.cc | 208 + .../compiler/tf2xla/kernels/strided_slice_op.cc | 223 + tensorflow/compiler/tf2xla/kernels/tile_ops.cc | 128 + tensorflow/compiler/tf2xla/kernels/transpose_op.cc | 134 + tensorflow/compiler/tf2xla/kernels/unary_ops.cc | 70 + tensorflow/compiler/tf2xla/kernels/unpack_op.cc | 90 + tensorflow/compiler/tf2xla/literal_util.cc | 65 + tensorflow/compiler/tf2xla/literal_util.h | 42 + tensorflow/compiler/tf2xla/literal_util_test.cc | 71 + tensorflow/compiler/tf2xla/op_registrations.cc | 502 ++ tensorflow/compiler/tf2xla/shape_util.cc | 54 + tensorflow/compiler/tf2xla/shape_util.h | 38 + tensorflow/compiler/tf2xla/str_util.cc | 44 + tensorflow/compiler/tf2xla/str_util.h | 46 + tensorflow/compiler/tf2xla/str_util_test.cc | 90 + tensorflow/compiler/tf2xla/type_util.cc | 68 + tensorflow/compiler/tf2xla/type_util.h | 30 + .../compiler/tf2xla/xla_compilation_device.cc | 203 + .../compiler/tf2xla/xla_compilation_device.h | 214 + tensorflow/compiler/tf2xla/xla_compiler.cc | 405 ++ tensorflow/compiler/tf2xla/xla_compiler.h | 203 + tensorflow/compiler/tf2xla/xla_context.cc | 331 ++ tensorflow/compiler/tf2xla/xla_context.h | 277 + tensorflow/compiler/tf2xla/xla_helpers.cc | 142 + tensorflow/compiler/tf2xla/xla_helpers.h | 73 + .../compiler/tf2xla/xla_local_runtime_context.h | 55 + tensorflow/compiler/tf2xla/xla_op_kernel.cc | 253 + tensorflow/compiler/tf2xla/xla_op_kernel.h | 174 + tensorflow/compiler/xla/.clang-format | 3 + tensorflow/compiler/xla/BUILD | 561 ++ tensorflow/compiler/xla/README.md | 1 + tensorflow/compiler/xla/array2d.cc | 36 + tensorflow/compiler/xla/array2d.h | 165 + tensorflow/compiler/xla/array2d_test.cc | 132 + tensorflow/compiler/xla/array3d.h | 127 + tensorflow/compiler/xla/array3d_test.cc | 93 + tensorflow/compiler/xla/array4d.h | 272 + tensorflow/compiler/xla/array4d_test.cc | 180 + tensorflow/compiler/xla/client/BUILD | 175 + tensorflow/compiler/xla/client/client.cc | 479 ++ tensorflow/compiler/xla/client/client.h | 202 + tensorflow/compiler/xla/client/client_library.cc | 107 + tensorflow/compiler/xla/client/client_library.h | 103 + tensorflow/compiler/xla/client/computation.cc | 67 + tensorflow/compiler/xla/client/computation.h | 76 + .../compiler/xla/client/computation_builder.cc | 1539 ++++++ .../compiler/xla/client/computation_builder.h | 783 +++ tensorflow/compiler/xla/client/global_data.cc | 42 + tensorflow/compiler/xla/client/global_data.h | 46 
+ tensorflow/compiler/xla/client/lib/BUILD | 60 + tensorflow/compiler/xla/client/lib/arithmetic.cc | 67 + tensorflow/compiler/xla/client/lib/arithmetic.h | 45 + tensorflow/compiler/xla/client/lib/testing.cc | 59 + tensorflow/compiler/xla/client/lib/testing.h | 43 + tensorflow/compiler/xla/client/local_client.cc | 371 ++ tensorflow/compiler/xla/client/local_client.h | 263 + tensorflow/compiler/xla/client/padding.cc | 122 + tensorflow/compiler/xla/client/padding.h | 58 + tensorflow/compiler/xla/client/padding_test.cc | 91 + tensorflow/compiler/xla/device_util.h | 39 + tensorflow/compiler/xla/differential_set.h | 63 + tensorflow/compiler/xla/differential_set_test.cc | 51 + tensorflow/compiler/xla/executable_run_options.cc | 70 + tensorflow/compiler/xla/executable_run_options.h | 87 + tensorflow/compiler/xla/index_util.cc | 126 + tensorflow/compiler/xla/index_util.h | 69 + tensorflow/compiler/xla/index_util_test.cc | 159 + tensorflow/compiler/xla/layout_util.cc | 363 ++ tensorflow/compiler/xla/layout_util.h | 153 + tensorflow/compiler/xla/layout_util_test.cc | 246 + tensorflow/compiler/xla/legacy_flags/BUILD | 267 + .../xla/legacy_flags/alias_analysis_flags.cc | 62 + .../xla/legacy_flags/alias_analysis_flags.h | 46 + .../compiler/xla/legacy_flags/backend_flags.cc | 63 + .../compiler/xla/legacy_flags/backend_flags.h | 46 + .../xla/legacy_flags/buffer_assignment_flags.cc | 63 + .../xla/legacy_flags/buffer_assignment_flags.h | 46 + .../xla/legacy_flags/compiler_functor_flags.cc | 61 + .../xla/legacy_flags/compiler_functor_flags.h | 47 + .../xla/legacy_flags/convolution_thunk_flags.cc | 63 + .../xla/legacy_flags/convolution_thunk_flags.h | 47 + .../xla/legacy_flags/cpu_compiler_flags.cc | 76 + .../compiler/xla/legacy_flags/cpu_compiler_flags.h | 54 + .../compiler/xla/legacy_flags/cpu_runtime_flags.cc | 71 + .../compiler/xla/legacy_flags/cpu_runtime_flags.h | 51 + .../xla/legacy_flags/gpu_backend_lib_flags.cc | 91 + .../xla/legacy_flags/gpu_backend_lib_flags.h | 56 + .../xla/legacy_flags/gpu_compiler_flags.cc | 73 + .../compiler/xla/legacy_flags/gpu_compiler_flags.h | 54 + .../xla/legacy_flags/hlo_graph_dumper_flags.cc | 63 + .../xla/legacy_flags/hlo_graph_dumper_flags.h | 47 + .../xla/legacy_flags/hlo_pass_pipeline_flags.cc | 62 + .../xla/legacy_flags/hlo_pass_pipeline_flags.h | 48 + .../xla/legacy_flags/hlo_test_base_flags.cc | 63 + .../xla/legacy_flags/hlo_test_base_flags.h | 47 + .../compiler/xla/legacy_flags/layout_util_flags.cc | 107 + .../compiler/xla/legacy_flags/layout_util_flags.h | 62 + .../xla/legacy_flags/llvm_backend_flags.cc | 67 + .../compiler/xla/legacy_flags/llvm_backend_flags.h | 58 + .../compiler/xla/legacy_flags/llvm_util_flags.cc | 63 + .../compiler/xla/legacy_flags/llvm_util_flags.h | 46 + .../xla/legacy_flags/parse_flags_from_env.cc | 206 + .../xla/legacy_flags/parse_flags_from_env.h | 66 + .../xla/legacy_flags/parse_flags_from_env_test.cc | 190 + .../compiler/xla/legacy_flags/service_flags.cc | 100 + .../compiler/xla/legacy_flags/service_flags.h | 69 + .../xla/legacy_flags/stream_assignment_flags.cc | 63 + .../xla/legacy_flags/stream_assignment_flags.h | 47 + tensorflow/compiler/xla/legacy_flags/util_flags.cc | 62 + tensorflow/compiler/xla/legacy_flags/util_flags.h | 45 + tensorflow/compiler/xla/literal_util.cc | 989 ++++ tensorflow/compiler/xla/literal_util.h | 1004 ++++ tensorflow/compiler/xla/literal_util_test.cc | 622 +++ tensorflow/compiler/xla/map_util.h | 65 + tensorflow/compiler/xla/packed_literal_reader.cc | 92 + tensorflow/compiler/xla/packed_literal_reader.h | 59 
+ tensorflow/compiler/xla/port/BUILD | 33 + tensorflow/compiler/xla/port/initialize.h | 39 + tensorflow/compiler/xla/primitive_util.cc | 133 + tensorflow/compiler/xla/primitive_util.h | 157 + tensorflow/compiler/xla/protobuf_util.cc | 35 + tensorflow/compiler/xla/protobuf_util.h | 35 + tensorflow/compiler/xla/ptr_util.h | 80 + tensorflow/compiler/xla/reference_util.cc | 540 ++ tensorflow/compiler/xla/reference_util.h | 382 ++ tensorflow/compiler/xla/reference_util_test.cc | 306 + tensorflow/compiler/xla/service/BUILD | 1216 ++++ .../compiler/xla/service/algebraic_simplifier.cc | 938 ++++ .../compiler/xla/service/algebraic_simplifier.h | 56 + .../xla/service/algebraic_simplifier_test.cc | 1368 +++++ .../compiler/xla/service/allocation_tracker.cc | 215 + .../compiler/xla/service/allocation_tracker.h | 178 + tensorflow/compiler/xla/service/backend.cc | 237 + tensorflow/compiler/xla/service/backend.h | 191 + .../compiler/xla/service/buffer_assignment.cc | 777 +++ .../compiler/xla/service/buffer_assignment.h | 358 ++ .../compiler/xla/service/buffer_assignment_test.cc | 1051 ++++ tensorflow/compiler/xla/service/buffer_liveness.cc | 259 + tensorflow/compiler/xla/service/buffer_liveness.h | 215 + .../compiler/xla/service/buffer_liveness_test.cc | 487 ++ tensorflow/compiler/xla/service/channel_tracker.cc | 91 + tensorflow/compiler/xla/service/channel_tracker.h | 94 + .../compiler/xla/service/compilation_cache.cc | 78 + .../compiler/xla/service/compilation_cache.h | 78 + tensorflow/compiler/xla/service/compiler.cc | 96 + tensorflow/compiler/xla/service/compiler.h | 172 + .../compiler/xla/service/computation_layout.cc | 57 + .../compiler/xla/service/computation_layout.h | 83 + .../compiler/xla/service/computation_tracker.cc | 204 + .../compiler/xla/service/computation_tracker.h | 139 + tensorflow/compiler/xla/service/copy_insertion.cc | 439 ++ tensorflow/compiler/xla/service/copy_insertion.h | 54 + .../compiler/xla/service/copy_insertion_test.cc | 1153 ++++ tensorflow/compiler/xla/service/cpu/BUILD | 529 ++ tensorflow/compiler/xla/service/cpu/build_defs.bzl | 11 + .../compiler/xla/service/cpu/compiler_functor.cc | 220 + .../compiler/xla/service/cpu/compiler_functor.h | 69 + .../xla/service/cpu/conv_canonicalization.cc | 148 + .../xla/service/cpu/conv_canonicalization.h | 44 + .../xla/service/cpu/conv_canonicalization_test.cc | 146 + .../compiler/xla/service/cpu/cpu_compiler.cc | 631 +++ tensorflow/compiler/xla/service/cpu/cpu_compiler.h | 148 + .../compiler/xla/service/cpu/cpu_executable.cc | 477 ++ .../compiler/xla/service/cpu/cpu_executable.h | 150 + .../xla/service/cpu/cpu_instruction_fusion.cc | 44 + .../xla/service/cpu/cpu_instruction_fusion.h | 37 + .../service/cpu/cpu_parallelization_preparation.cc | 120 + .../service/cpu/cpu_parallelization_preparation.h | 46 + tensorflow/compiler/xla/service/cpu/cpu_runtime.cc | 52 + tensorflow/compiler/xla/service/cpu/cpu_runtime.h | 91 + .../compiler/xla/service/cpu/cpu_runtime_avx.cc | 36 + .../compiler/xla/service/cpu/cpu_runtime_avx.h | 50 + .../compiler/xla/service/cpu/cpu_runtime_sse4_1.cc | 47 + .../compiler/xla/service/cpu/cpu_runtime_sse4_1.h | 50 + .../compiler/xla/service/cpu/cpu_runtime_test.cc | 138 + .../compiler/xla/service/cpu/disassembler.cc | 182 + tensorflow/compiler/xla/service/cpu/disassembler.h | 63 + .../compiler/xla/service/cpu/dot_op_emitter.cc | 346 ++ .../compiler/xla/service/cpu/dot_op_emitter.h | 90 + .../xla/service/cpu/elemental_ir_emitter.cc | 68 + .../xla/service/cpu/elemental_ir_emitter.h | 43 + 
.../compiler/xla/service/cpu/infeed_manager.cc | 72 + .../compiler/xla/service/cpu/infeed_manager.h | 95 + .../xla/service/cpu/infeed_manager_test.cc | 102 + .../compiler/xla/service/cpu/ir_emission_utils.cc | 127 + .../compiler/xla/service/cpu/ir_emission_utils.h | 32 + tensorflow/compiler/xla/service/cpu/ir_emitter.cc | 1774 ++++++ tensorflow/compiler/xla/service/cpu/ir_emitter.h | 402 ++ .../compiler/xla/service/cpu/layout_assignment.cc | 124 + .../compiler/xla/service/cpu/layout_assignment.h | 41 + .../xla/service/cpu/parallel_cpu_executable.cc | 365 ++ .../xla/service/cpu/parallel_cpu_executable.h | 124 + .../compiler/xla/service/cpu/runtime_conv2d.cc | 43 + .../compiler/xla/service/cpu/runtime_conv2d.h | 39 + .../compiler/xla/service/cpu/runtime_conv2d_impl.h | 87 + .../compiler/xla/service/cpu/runtime_matmul.cc | 81 + .../compiler/xla/service/cpu/runtime_matmul.h | 42 + .../service/cpu/runtime_single_threaded_conv2d.cc | 39 + .../service/cpu/runtime_single_threaded_conv2d.h | 39 + .../service/cpu/runtime_single_threaded_matmul.cc | 73 + .../service/cpu/runtime_single_threaded_matmul.h | 42 + .../compiler/xla/service/cpu/sample_harness.cc | 75 + .../compiler/xla/service/cpu/simple_orc_jit.cc | 189 + .../compiler/xla/service/cpu/simple_orc_jit.h | 88 + .../compiler/xla/service/cpu_transfer_manager.cc | 108 + .../compiler/xla/service/cpu_transfer_manager.h | 47 + .../xla/service/device_memory_allocator.cc | 77 + .../compiler/xla/service/device_memory_allocator.h | 84 + tensorflow/compiler/xla/service/dfs_hlo_visitor.cc | 78 + tensorflow/compiler/xla/service/dfs_hlo_visitor.h | 289 + .../xla/service/dfs_hlo_visitor_with_default.h | 226 + .../compiler/xla/service/elemental_ir_emitter.cc | 934 ++++ .../compiler/xla/service/elemental_ir_emitter.h | 118 + tensorflow/compiler/xla/service/executable.cc | 82 + tensorflow/compiler/xla/service/executable.h | 168 + .../compiler/xla/service/execution_tracker.cc | 95 + .../compiler/xla/service/execution_tracker.h | 105 + .../xla/service/generic_transfer_manager.cc | 183 + .../xla/service/generic_transfer_manager.h | 77 + tensorflow/compiler/xla/service/gpu/BUILD | 533 ++ .../compiler/xla/service/gpu/buffer_allocations.cc | 139 + .../compiler/xla/service/gpu/buffer_allocations.h | 113 + .../xla/service/gpu/convolution_folding.cc | 443 ++ .../compiler/xla/service/gpu/convolution_folding.h | 34 + .../xla/service/gpu/convolution_folding_test.cc | 552 ++ .../compiler/xla/service/gpu/convolution_thunk.cc | 324 ++ .../compiler/xla/service/gpu/convolution_thunk.h | 149 + .../compiler/xla/service/gpu/copy_insertion.cc | 71 + .../compiler/xla/service/gpu/copy_insertion.h | 36 + tensorflow/compiler/xla/service/gpu/copy_thunk.cc | 41 + tensorflow/compiler/xla/service/gpu/copy_thunk.h | 56 + .../xla/service/gpu/elemental_ir_emitter.cc | 396 ++ .../xla/service/gpu/elemental_ir_emitter.h | 91 + tensorflow/compiler/xla/service/gpu/for_thunk.cc | 50 + tensorflow/compiler/xla/service/gpu/for_thunk.h | 52 + tensorflow/compiler/xla/service/gpu/gemm_thunk.cc | 189 + tensorflow/compiler/xla/service/gpu/gemm_thunk.h | 71 + .../compiler/xla/service/gpu/gpu_compiler.cc | 335 ++ tensorflow/compiler/xla/service/gpu/gpu_compiler.h | 78 + .../compiler/xla/service/gpu/gpu_executable.cc | 454 ++ .../compiler/xla/service/gpu/gpu_executable.h | 130 + .../compiler/xla/service/gpu/hlo_schedule.cc | 207 + tensorflow/compiler/xla/service/gpu/hlo_schedule.h | 67 + .../compiler/xla/service/gpu/hlo_schedule_test.cc | 368 ++ .../compiler/xla/service/gpu/hlo_to_ir_bindings.cc | 168 + 
.../compiler/xla/service/gpu/hlo_to_ir_bindings.h | 109 + .../compiler/xla/service/gpu/instruction_fusion.cc | 90 + .../compiler/xla/service/gpu/instruction_fusion.h | 39 + .../xla/service/gpu/instruction_fusion_test.cc | 126 + .../compiler/xla/service/gpu/ir_emission_utils.cc | 200 + .../compiler/xla/service/gpu/ir_emission_utils.h | 72 + tensorflow/compiler/xla/service/gpu/ir_emitter.cc | 645 +++ tensorflow/compiler/xla/service/gpu/ir_emitter.h | 405 ++ .../compiler/xla/service/gpu/ir_emitter_context.h | 74 + .../compiler/xla/service/gpu/ir_emitter_nested.cc | 120 + .../xla/service/gpu/ir_emitter_unnested.cc | 1745 ++++++ .../compiler/xla/service/gpu/kernel_thunk.cc | 94 + tensorflow/compiler/xla/service/gpu/kernel_thunk.h | 86 + .../compiler/xla/service/gpu/layout_assignment.cc | 142 + .../compiler/xla/service/gpu/layout_assignment.h | 41 + .../xla/service/gpu/layout_assignment_test.cc | 85 + .../xla/service/gpu/llvm_gpu_backend/BUILD | 88 + .../service/gpu/llvm_gpu_backend/dump_ir_pass.cc | 103 + .../service/gpu/llvm_gpu_backend/dump_ir_pass.h | 51 + .../gpu/llvm_gpu_backend/gpu_backend_lib.cc | 489 ++ .../service/gpu/llvm_gpu_backend/gpu_backend_lib.h | 43 + .../gpu/llvm_gpu_backend/tests_data/saxpy.ll | 141 + .../xla/service/gpu/llvm_gpu_backend/utils.cc | 65 + .../xla/service/gpu/llvm_gpu_backend/utils.h | 50 + .../xla/service/gpu/llvm_gpu_backend/utils_test.cc | 55 + .../compiler/xla/service/gpu/pad_insertion.cc | 408 ++ .../compiler/xla/service/gpu/pad_insertion.h | 43 + .../xla/service/gpu/parallel_loop_emitter.cc | 98 + .../xla/service/gpu/parallel_loop_emitter.h | 58 + .../xla/service/gpu/partition_assignment.cc | 99 + .../xla/service/gpu/partition_assignment.h | 75 + .../compiler/xla/service/gpu/sequential_thunk.cc | 45 + .../compiler/xla/service/gpu/sequential_thunk.h | 54 + .../compiler/xla/service/gpu/stream_assignment.cc | 135 + .../compiler/xla/service/gpu/stream_assignment.h | 46 + .../xla/service/gpu/stream_assignment_test.cc | 132 + .../xla/service/gpu/temp_buffer_offsets.cc | 52 + .../compiler/xla/service/gpu/temp_buffer_offsets.h | 47 + tensorflow/compiler/xla/service/gpu/thunk.h | 90 + .../compiler/xla/service/gpu/thunk_schedule.cc | 163 + .../compiler/xla/service/gpu/thunk_schedule.h | 93 + tensorflow/compiler/xla/service/gpu/tuple_thunk.cc | 49 + tensorflow/compiler/xla/service/gpu/tuple_thunk.h | 60 + tensorflow/compiler/xla/service/gpu/while_thunk.cc | 74 + tensorflow/compiler/xla/service/gpu/while_thunk.h | 62 + .../compiler/xla/service/gpu/while_transformer.cc | 532 ++ .../compiler/xla/service/gpu/while_transformer.h | 43 + .../xla/service/gpu/while_transformer_test.cc | 218 + .../compiler/xla/service/graphviz_example.cc | 165 + tensorflow/compiler/xla/service/hlo_computation.cc | 520 ++ tensorflow/compiler/xla/service/hlo_computation.h | 300 + .../compiler/xla/service/hlo_computation_test.cc | 311 ++ .../compiler/xla/service/hlo_cost_analysis.cc | 350 ++ .../compiler/xla/service/hlo_cost_analysis.h | 147 + .../compiler/xla/service/hlo_cost_analysis_test.cc | 337 ++ tensorflow/compiler/xla/service/hlo_cse.cc | 134 + tensorflow/compiler/xla/service/hlo_cse.h | 46 + tensorflow/compiler/xla/service/hlo_cse_test.cc | 428 ++ tensorflow/compiler/xla/service/hlo_dce.cc | 69 + tensorflow/compiler/xla/service/hlo_dce.h | 43 + tensorflow/compiler/xla/service/hlo_dce_test.cc | 97 + .../compiler/xla/service/hlo_execution_profile.cc | 87 + .../compiler/xla/service/hlo_execution_profile.h | 71 + .../compiler/xla/service/hlo_graph_dumper.cc | 507 ++ 
tensorflow/compiler/xla/service/hlo_graph_dumper.h | 76 + tensorflow/compiler/xla/service/hlo_instruction.cc | 1921 +++++++ tensorflow/compiler/xla/service/hlo_instruction.h | 791 +++ .../compiler/xla/service/hlo_instruction_test.cc | 894 +++ tensorflow/compiler/xla/service/hlo_module.cc | 269 + tensorflow/compiler/xla/service/hlo_module.h | 132 + .../compiler/xla/service/hlo_module_config.cc | 53 + .../compiler/xla/service/hlo_module_config.h | 92 + tensorflow/compiler/xla/service/hlo_module_test.cc | 101 + tensorflow/compiler/xla/service/hlo_opcode.cc | 164 + tensorflow/compiler/xla/service/hlo_opcode.h | 107 + tensorflow/compiler/xla/service/hlo_opcode_test.cc | 30 + tensorflow/compiler/xla/service/hlo_pass.h | 68 + .../compiler/xla/service/hlo_pass_pipeline.cc | 64 + .../compiler/xla/service/hlo_pass_pipeline.h | 66 + tensorflow/compiler/xla/service/hlo_query.cc | 89 + tensorflow/compiler/xla/service/hlo_query.h | 63 + .../xla/service/hlo_subcomputation_unification.cc | 45 + .../xla/service/hlo_subcomputation_unification.h | 34 + .../service/hlo_subcomputation_unification_test.cc | 205 + tensorflow/compiler/xla/service/inliner.cc | 123 + tensorflow/compiler/xla/service/inliner.h | 39 + tensorflow/compiler/xla/service/inliner_test.cc | 109 + .../compiler/xla/service/instruction_fusion.cc | 295 + .../compiler/xla/service/instruction_fusion.h | 84 + .../xla/service/instruction_fusion_test.cc | 140 + .../compiler/xla/service/layout_assignment.cc | 1334 +++++ .../compiler/xla/service/layout_assignment.h | 302 + .../compiler/xla/service/layout_assignment_test.cc | 486 ++ tensorflow/compiler/xla/service/llvm_ir/BUILD | 154 + tensorflow/compiler/xla/service/llvm_ir/README.md | 2 + .../compiler/xla/service/llvm_ir/alias_analysis.cc | 195 + .../compiler/xla/service/llvm_ir/alias_analysis.h | 93 + .../xla/service/llvm_ir/fused_ir_emitter.cc | 147 + .../xla/service/llvm_ir/fused_ir_emitter.h | 94 + .../compiler/xla/service/llvm_ir/ir_array.cc | 274 + tensorflow/compiler/xla/service/llvm_ir/ir_array.h | 248 + .../compiler/xla/service/llvm_ir/llvm_loop.cc | 197 + .../compiler/xla/service/llvm_ir/llvm_loop.h | 230 + .../compiler/xla/service/llvm_ir/llvm_util.cc | 471 ++ .../compiler/xla/service/llvm_ir/llvm_util.h | 228 + .../compiler/xla/service/llvm_ir/loop_emitter.cc | 103 + .../compiler/xla/service/llvm_ir/loop_emitter.h | 79 + tensorflow/compiler/xla/service/llvm_ir/ops.cc | 100 + tensorflow/compiler/xla/service/llvm_ir/ops.h | 79 + tensorflow/compiler/xla/service/local_service.cc | 543 ++ tensorflow/compiler/xla/service/local_service.h | 185 + tensorflow/compiler/xla/service/logical_buffer.cc | 39 + tensorflow/compiler/xla/service/logical_buffer.h | 153 + tensorflow/compiler/xla/service/name_uniquer.cc | 37 + tensorflow/compiler/xla/service/name_uniquer.h | 53 + tensorflow/compiler/xla/service/platform_util.cc | 166 + tensorflow/compiler/xla/service/platform_util.h | 61 + tensorflow/compiler/xla/service/reshape_mover.cc | 120 + tensorflow/compiler/xla/service/reshape_mover.h | 36 + .../compiler/xla/service/reshape_mover_test.cc | 57 + tensorflow/compiler/xla/service/service.cc | 1428 +++++ tensorflow/compiler/xla/service/service.h | 457 ++ tensorflow/compiler/xla/service/session.proto | 91 + tensorflow/compiler/xla/service/shape_inference.cc | 1380 +++++ tensorflow/compiler/xla/service/shape_inference.h | 219 + .../compiler/xla/service/shape_inference_test.cc | 1133 ++++ tensorflow/compiler/xla/service/shaped_buffer.cc | 168 + tensorflow/compiler/xla/service/shaped_buffer.h | 137 + 
.../compiler/xla/service/transfer_manager.cc | 143 + tensorflow/compiler/xla/service/transfer_manager.h | 151 + .../compiler/xla/service/transfer_manager_test.cc | 159 + .../compiler/xla/service/transpose_folding.cc | 109 + .../compiler/xla/service/transpose_folding.h | 41 + .../compiler/xla/service/transpose_folding_test.cc | 149 + .../xla/service/tuple_points_to_analysis.cc | 495 ++ .../xla/service/tuple_points_to_analysis.h | 268 + .../xla/service/tuple_points_to_analysis_test.cc | 544 ++ .../compiler/xla/service/user_computation.cc | 2117 +++++++ tensorflow/compiler/xla/service/user_computation.h | 336 ++ .../xla/service/versioned_computation_handle.h | 48 + tensorflow/compiler/xla/service_interface.h | 117 + tensorflow/compiler/xla/shape_layout.cc | 78 + tensorflow/compiler/xla/shape_layout.h | 88 + tensorflow/compiler/xla/shape_tree.h | 260 + tensorflow/compiler/xla/shape_tree_test.cc | 134 + tensorflow/compiler/xla/shape_util.cc | 1024 ++++ tensorflow/compiler/xla/shape_util.h | 393 ++ tensorflow/compiler/xla/shape_util_test.cc | 506 ++ tensorflow/compiler/xla/status.h | 46 + tensorflow/compiler/xla/status_macros.cc | 170 + tensorflow/compiler/xla/status_macros.h | 220 + tensorflow/compiler/xla/status_macros_test.cc | 112 + tensorflow/compiler/xla/statusor.cc | 46 + tensorflow/compiler/xla/statusor.h | 300 + tensorflow/compiler/xla/statusor_test.cc | 645 +++ tensorflow/compiler/xla/test_helpers.cc | 69 + tensorflow/compiler/xla/test_helpers.h | 355 ++ tensorflow/compiler/xla/tests/BUILD | 1436 +++++ .../xla/tests/array_elementwise_ops_test.cc | 1662 ++++++ tensorflow/compiler/xla/tests/axpy_simple_test.cc | 90 + .../xla/tests/bad_rng_shape_validation_test.cc | 85 + .../compiler/xla/tests/batch_normalization_test.cc | 210 + .../compiler/xla/tests/binop_scaling_test.cc | 157 + .../compiler/xla/tests/broadcast_simple_test.cc | 179 + tensorflow/compiler/xla/tests/broadcast_test.cc | 286 + tensorflow/compiler/xla/tests/build_defs.bzl | 149 + tensorflow/compiler/xla/tests/call_test.cc | 115 + .../xla/tests/check_execution_arity_test.cc | 138 + .../compiler/xla/tests/client_library_test_base.cc | 263 + .../compiler/xla/tests/client_library_test_base.h | 409 ++ tensorflow/compiler/xla/tests/client_test.cc | 127 + tensorflow/compiler/xla/tests/codegen_test_base.cc | 90 + tensorflow/compiler/xla/tests/codegen_test_base.h | 56 + .../compiler/xla/tests/compilation_cache_test.cc | 218 + .../compiler/xla/tests/compute_constant_test.cc | 249 + tensorflow/compiler/xla/tests/concat_test.cc | 523 ++ tensorflow/compiler/xla/tests/constants_test.cc | 193 + tensorflow/compiler/xla/tests/convert_test.cc | 210 + .../tests/convolution_dimension_numbers_test.cc | 117 + tensorflow/compiler/xla/tests/convolution_test.cc | 361 ++ .../xla/tests/convolution_variants_test.cc | 1294 +++++ tensorflow/compiler/xla/tests/copy_test.cc | 277 + tensorflow/compiler/xla/tests/custom_call_test.cc | 148 + tensorflow/compiler/xla/tests/deallocation_test.cc | 155 + .../compiler/xla/tests/deconstruct_tuple_test.cc | 215 + .../compiler/xla/tests/dot_operation_test.cc | 387 ++ tensorflow/compiler/xla/tests/dynamic_ops_test.cc | 506 ++ tensorflow/compiler/xla/tests/floor_ceil_test.cc | 128 + tensorflow/compiler/xla/tests/fmax_test.cc | 61 + tensorflow/compiler/xla/tests/fusion_test.cc | 589 ++ tensorflow/compiler/xla/tests/hlo_test_base.cc | 204 + tensorflow/compiler/xla/tests/hlo_test_base.h | 107 + .../compiler/xla/tests/inprocess_service_test.cc | 204 + tensorflow/compiler/xla/tests/literal_test_util.cc | 566 ++ 
tensorflow/compiler/xla/tests/literal_test_util.h | 274 + .../compiler/xla/tests/literal_test_util_test.cc | 102 + .../compiler/xla/tests/local_client_aot_test.cc | 55 + .../xla/tests/local_client_aot_test_helper.cc | 111 + .../compiler/xla/tests/local_client_test_base.cc | 220 + .../compiler/xla/tests/local_client_test_base.h | 146 + tensorflow/compiler/xla/tests/log_test.cc | 75 + tensorflow/compiler/xla/tests/map_test.cc | 589 ++ .../compiler/xla/tests/matrix_ops_simple_test.cc | 179 + .../xla/tests/multidimensional_slice_test.cc | 74 + tensorflow/compiler/xla/tests/pad_test.cc | 420 ++ tensorflow/compiler/xla/tests/params_test.cc | 357 ++ tensorflow/compiler/xla/tests/pred_test.cc | 115 + tensorflow/compiler/xla/tests/prng_test.cc | 238 + .../xla/tests/query_inferred_shape_test.cc | 61 + tensorflow/compiler/xla/tests/reduce_test.cc | 506 ++ .../compiler/xla/tests/reduce_window_test.cc | 445 ++ tensorflow/compiler/xla/tests/replay_test.cc | 168 + .../compiler/xla/tests/reshape_motion_test.cc | 77 + tensorflow/compiler/xla/tests/reshape_test.cc | 811 +++ tensorflow/compiler/xla/tests/reverse_test.cc | 173 + .../xla/tests/round_trip_packed_literal_test.cc | 160 + .../compiler/xla/tests/round_trip_transfer_test.cc | 164 + .../compiler/xla/tests/scalar_computations_test.cc | 630 +++ .../compiler/xla/tests/select_and_scatter_test.cc | 395 ++ tensorflow/compiler/xla/tests/select_test.cc | 276 + .../compiler/xla/tests/set_return_value_test.cc | 116 + tensorflow/compiler/xla/tests/slice_test.cc | 277 + tensorflow/compiler/xla/tests/test_macros.h | 76 + tensorflow/compiler/xla/tests/test_utils.h | 115 + tensorflow/compiler/xla/tests/transpose_test.cc | 203 + tensorflow/compiler/xla/tests/tuple_test.cc | 415 ++ tensorflow/compiler/xla/tests/unary_op_test.cc | 179 + .../compiler/xla/tests/vector_ops_reduce_test.cc | 235 + .../compiler/xla/tests/vector_ops_simple_test.cc | 423 ++ tensorflow/compiler/xla/tests/while_test.cc | 395 ++ tensorflow/compiler/xla/text_literal_reader.cc | 155 + tensorflow/compiler/xla/text_literal_reader.h | 62 + .../compiler/xla/text_literal_reader_test.cc | 58 + tensorflow/compiler/xla/text_literal_writer.cc | 64 + tensorflow/compiler/xla/text_literal_writer.h | 48 + .../compiler/xla/text_literal_writer_test.cc | 52 + tensorflow/compiler/xla/tools/BUILD | 191 + .../compiler/xla/tools/convert_computation.cc | 60 + .../xla/tools/dumped_computation_to_graphviz.cc | 76 + .../tools/dumped_computation_to_operation_list.cc | 111 + .../xla/tools/dumped_computation_to_text.cc | 83 + .../xla/tools/hex_floats_to_packed_literal.cc | 76 + .../compiler/xla/tools/replay_computation.cc | 129 + tensorflow/compiler/xla/tools/show_literal.cc | 45 + tensorflow/compiler/xla/tools/show_signature.cc | 73 + tensorflow/compiler/xla/tools/show_text_literal.cc | 52 + tensorflow/compiler/xla/types.h | 37 + tensorflow/compiler/xla/util.cc | 238 + tensorflow/compiler/xla/util.h | 257 + tensorflow/compiler/xla/util_test.cc | 89 + tensorflow/compiler/xla/window_util.cc | 142 + tensorflow/compiler/xla/window_util.h | 66 + tensorflow/compiler/xla/xla.bzl | 22 + tensorflow/compiler/xla/xla.proto | 291 + tensorflow/compiler/xla/xla_data.proto | 714 +++ tensorflow/contrib/compiler/BUILD | 1 + tensorflow/core/platform/default/build_config.bzl | 12 +- tensorflow/tools/ci_build/builds/configured | 4 + tensorflow/tools/pip_package/BUILD | 3 +- third_party/llvm/llvm.BUILD | 1 + 656 files changed, 138481 insertions(+), 2 deletions(-) create mode 100644 tensorflow/compiler/aot/BUILD create mode 100644 
tensorflow/compiler/aot/benchmark.cc create mode 100644 tensorflow/compiler/aot/benchmark.h create mode 100644 tensorflow/compiler/aot/benchmark_main.template create mode 100644 tensorflow/compiler/aot/benchmark_test.cc create mode 100644 tensorflow/compiler/aot/codegen.cc create mode 100644 tensorflow/compiler/aot/codegen.h create mode 100644 tensorflow/compiler/aot/codegen_test.cc create mode 100644 tensorflow/compiler/aot/codegen_test_h.golden create mode 100644 tensorflow/compiler/aot/compile.cc create mode 100644 tensorflow/compiler/aot/compile.h create mode 100644 tensorflow/compiler/aot/flags.cc create mode 100644 tensorflow/compiler/aot/flags.h create mode 100644 tensorflow/compiler/aot/runtime.cc create mode 100644 tensorflow/compiler/aot/runtime.h create mode 100644 tensorflow/compiler/aot/runtime_test.cc create mode 100644 tensorflow/compiler/aot/test.cc create mode 100644 tensorflow/compiler/aot/test_graph_tfadd.config.pbtxt create mode 100644 tensorflow/compiler/aot/test_graph_tfadd.pbtxt create mode 100644 tensorflow/compiler/aot/tests/BUILD create mode 100644 tensorflow/compiler/aot/tests/make_test_graphs.py create mode 100644 tensorflow/compiler/aot/tests/test_graph_tfadd.config.pbtxt create mode 100644 tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.config.pbtxt create mode 100644 tensorflow/compiler/aot/tests/test_graph_tfgather.config.pbtxt create mode 100644 tensorflow/compiler/aot/tests/test_graph_tfmatmul.config.pbtxt create mode 100644 tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.config.pbtxt create mode 100644 tensorflow/compiler/aot/tests/tfcompile_test.cc create mode 100644 tensorflow/compiler/aot/tfcompile.bzl create mode 100644 tensorflow/compiler/aot/tfcompile.proto create mode 100644 tensorflow/compiler/aot/tfcompile_main.cc create mode 100644 tensorflow/compiler/aot/tfcompile_util.cc create mode 100644 tensorflow/compiler/aot/tfcompile_util.h create mode 100644 tensorflow/compiler/aot/tfcompile_util_test.cc create mode 100644 tensorflow/compiler/jit/BUILD create mode 100644 tensorflow/compiler/jit/build_xla_launch_ops_pass.cc create mode 100644 tensorflow/compiler/jit/build_xla_launch_ops_pass.h create mode 100644 tensorflow/compiler/jit/defs.cc create mode 100644 tensorflow/compiler/jit/defs.h create mode 100644 tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc create mode 100644 tensorflow/compiler/jit/encapsulate_subgraphs_pass.h create mode 100644 tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc create mode 100644 tensorflow/compiler/jit/graph_to_functiondef.cc create mode 100644 tensorflow/compiler/jit/graph_to_functiondef.h create mode 100644 tensorflow/compiler/jit/graph_to_functiondef_test.cc create mode 100644 tensorflow/compiler/jit/graphcycles/BUILD create mode 100644 tensorflow/compiler/jit/graphcycles/graphcycles.cc create mode 100644 tensorflow/compiler/jit/graphcycles/graphcycles.h create mode 100644 tensorflow/compiler/jit/graphcycles/graphcycles_test.cc create mode 100644 tensorflow/compiler/jit/jit_compilation_pass_registration.cc create mode 100644 tensorflow/compiler/jit/legacy_flags/BUILD create mode 100644 tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.cc create mode 100644 tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h create mode 100644 tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc create mode 100644 tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h create mode 100644 
tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.cc create mode 100644 tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h create mode 100644 tensorflow/compiler/jit/mark_for_compilation_pass.cc create mode 100644 tensorflow/compiler/jit/mark_for_compilation_pass.h create mode 100644 tensorflow/compiler/jit/mark_for_compilation_pass_test.cc create mode 100644 tensorflow/compiler/jit/parallel_check_op.cc create mode 100644 tensorflow/compiler/jit/xla_compilation_cache.cc create mode 100644 tensorflow/compiler/jit/xla_compilation_cache.h create mode 100644 tensorflow/compiler/jit/xla_cpu_device.cc create mode 100644 tensorflow/compiler/jit/xla_device.cc create mode 100644 tensorflow/compiler/jit/xla_device.h create mode 100644 tensorflow/compiler/jit/xla_device_context.cc create mode 100644 tensorflow/compiler/jit/xla_device_context.h create mode 100644 tensorflow/compiler/jit/xla_device_launch_op.cc create mode 100644 tensorflow/compiler/jit/xla_device_launch_op.h create mode 100644 tensorflow/compiler/jit/xla_device_ops.cc create mode 100644 tensorflow/compiler/jit/xla_device_ops.h create mode 100644 tensorflow/compiler/jit/xla_gpu_device.cc create mode 100644 tensorflow/compiler/jit/xla_local_launch_op.cc create mode 100644 tensorflow/compiler/jit/xla_local_launch_op.h create mode 100644 tensorflow/compiler/tests/BUILD create mode 100644 tensorflow/compiler/tests/binary_ops_test.py create mode 100644 tensorflow/compiler/tests/build_defs.bzl create mode 100644 tensorflow/compiler/tests/clustering_test.py create mode 100644 tensorflow/compiler/tests/concat_ops_test.py create mode 100644 tensorflow/compiler/tests/conv2d_test.py create mode 100644 tensorflow/compiler/tests/depthwise_conv2d_test_kernel.cc create mode 100644 tensorflow/compiler/tests/dynamic_stitch_test.py create mode 100644 tensorflow/compiler/tests/function_test.py create mode 100644 tensorflow/compiler/tests/jit_test.py create mode 100644 tensorflow/compiler/tests/lrn_ops_test.py create mode 100644 tensorflow/compiler/tests/lstm.py create mode 100644 tensorflow/compiler/tests/lstm_layer_inference.config.pbtxt create mode 100644 tensorflow/compiler/tests/lstm_layer_inference.pbtxt create mode 100644 tensorflow/compiler/tests/lstm_test.py create mode 100644 tensorflow/compiler/tests/nary_ops_test.py create mode 100644 tensorflow/compiler/tests/nullary_ops_test.py create mode 100644 tensorflow/compiler/tests/pooling_ops_test.py create mode 100644 tensorflow/compiler/tests/randomized_tests.cc create mode 100644 tensorflow/compiler/tests/reduce_ops_test.py create mode 100644 tensorflow/compiler/tests/ternary_ops_test.py create mode 100644 tensorflow/compiler/tests/unary_ops_test.py create mode 100644 tensorflow/compiler/tests/xla_device_test.py create mode 100644 tensorflow/compiler/tests/xla_test.py create mode 100644 tensorflow/compiler/tf2xla/BUILD create mode 100644 tensorflow/compiler/tf2xla/const_analysis.cc create mode 100644 tensorflow/compiler/tf2xla/const_analysis.h create mode 100644 tensorflow/compiler/tf2xla/const_analysis_test.cc create mode 100644 tensorflow/compiler/tf2xla/dump_graph.cc create mode 100644 tensorflow/compiler/tf2xla/dump_graph.h create mode 100644 tensorflow/compiler/tf2xla/dump_graph_flags.cc create mode 100644 tensorflow/compiler/tf2xla/dump_graph_flags.h create mode 100644 tensorflow/compiler/tf2xla/kernels/BUILD create mode 100644 tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc create mode 100644 tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc create mode 100644 
tensorflow/compiler/tf2xla/kernels/bcast_ops.cc create mode 100644 tensorflow/compiler/tf2xla/kernels/bias_ops.cc create mode 100644 tensorflow/compiler/tf2xla/kernels/binary_ops.cc create mode 100644 tensorflow/compiler/tf2xla/kernels/cast_op.cc create mode 100644 tensorflow/compiler/tf2xla/kernels/concat_op.cc create mode 100644 tensorflow/compiler/tf2xla/kernels/conv_ops.cc create mode 100644 tensorflow/compiler/tf2xla/kernels/cwise_ops.cc create mode 100644 tensorflow/compiler/tf2xla/kernels/cwise_ops.h create mode 100644 tensorflow/compiler/tf2xla/kernels/declaration_op.cc create mode 100644 tensorflow/compiler/tf2xla/kernels/depthwise_conv_ops.cc create mode 100644 tensorflow/compiler/tf2xla/kernels/diag_op.cc create mode 100644 tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc create mode 100644 tensorflow/compiler/tf2xla/kernels/fill_op.cc create mode 100644 tensorflow/compiler/tf2xla/kernels/function_ops.cc create mode 100644 tensorflow/compiler/tf2xla/kernels/gather_op.cc create mode 100644 tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int32.cc create mode 100644 tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int64.cc create mode 100644 tensorflow/compiler/tf2xla/kernels/identity_op.cc create mode 100644 tensorflow/compiler/tf2xla/kernels/index_ops.cc create mode 100644 tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc create mode 100644 tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc create mode 100644 tensorflow/compiler/tf2xla/kernels/l2loss_op.cc create mode 100644 tensorflow/compiler/tf2xla/kernels/lrn_ops.cc create mode 100644 tensorflow/compiler/tf2xla/kernels/matmul_op.cc create mode 100644 tensorflow/compiler/tf2xla/kernels/no_op.cc create mode 100644 tensorflow/compiler/tf2xla/kernels/pack_op.cc create mode 100644 tensorflow/compiler/tf2xla/kernels/pad_op.cc create mode 100644 tensorflow/compiler/tf2xla/kernels/pooling_ops.cc create mode 100644 tensorflow/compiler/tf2xla/kernels/random_ops.cc create mode 100644 tensorflow/compiler/tf2xla/kernels/reduction_ops.cc create mode 100644 tensorflow/compiler/tf2xla/kernels/reduction_ops.h create mode 100644 tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc create mode 100644 tensorflow/compiler/tf2xla/kernels/relu_op.cc create mode 100644 tensorflow/compiler/tf2xla/kernels/reshape_op.cc create mode 100644 tensorflow/compiler/tf2xla/kernels/retval_op.cc create mode 100644 tensorflow/compiler/tf2xla/kernels/select_op.cc create mode 100644 tensorflow/compiler/tf2xla/kernels/sequence_ops.cc create mode 100644 tensorflow/compiler/tf2xla/kernels/shape_op.cc create mode 100644 tensorflow/compiler/tf2xla/kernels/slice_op.cc create mode 100644 tensorflow/compiler/tf2xla/kernels/softmax_op.cc create mode 100644 tensorflow/compiler/tf2xla/kernels/split_op.cc create mode 100644 tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc create mode 100644 tensorflow/compiler/tf2xla/kernels/tile_ops.cc create mode 100644 tensorflow/compiler/tf2xla/kernels/transpose_op.cc create mode 100644 tensorflow/compiler/tf2xla/kernels/unary_ops.cc create mode 100644 tensorflow/compiler/tf2xla/kernels/unpack_op.cc create mode 100644 tensorflow/compiler/tf2xla/literal_util.cc create mode 100644 tensorflow/compiler/tf2xla/literal_util.h create mode 100644 tensorflow/compiler/tf2xla/literal_util_test.cc create mode 100644 tensorflow/compiler/tf2xla/op_registrations.cc create mode 100644 tensorflow/compiler/tf2xla/shape_util.cc create mode 100644 
tensorflow/compiler/tf2xla/shape_util.h create mode 100644 tensorflow/compiler/tf2xla/str_util.cc create mode 100644 tensorflow/compiler/tf2xla/str_util.h create mode 100644 tensorflow/compiler/tf2xla/str_util_test.cc create mode 100644 tensorflow/compiler/tf2xla/type_util.cc create mode 100644 tensorflow/compiler/tf2xla/type_util.h create mode 100644 tensorflow/compiler/tf2xla/xla_compilation_device.cc create mode 100644 tensorflow/compiler/tf2xla/xla_compilation_device.h create mode 100644 tensorflow/compiler/tf2xla/xla_compiler.cc create mode 100644 tensorflow/compiler/tf2xla/xla_compiler.h create mode 100644 tensorflow/compiler/tf2xla/xla_context.cc create mode 100644 tensorflow/compiler/tf2xla/xla_context.h create mode 100644 tensorflow/compiler/tf2xla/xla_helpers.cc create mode 100644 tensorflow/compiler/tf2xla/xla_helpers.h create mode 100644 tensorflow/compiler/tf2xla/xla_local_runtime_context.h create mode 100644 tensorflow/compiler/tf2xla/xla_op_kernel.cc create mode 100644 tensorflow/compiler/tf2xla/xla_op_kernel.h create mode 100644 tensorflow/compiler/xla/.clang-format create mode 100644 tensorflow/compiler/xla/BUILD create mode 100644 tensorflow/compiler/xla/README.md create mode 100644 tensorflow/compiler/xla/array2d.cc create mode 100644 tensorflow/compiler/xla/array2d.h create mode 100644 tensorflow/compiler/xla/array2d_test.cc create mode 100644 tensorflow/compiler/xla/array3d.h create mode 100644 tensorflow/compiler/xla/array3d_test.cc create mode 100644 tensorflow/compiler/xla/array4d.h create mode 100644 tensorflow/compiler/xla/array4d_test.cc create mode 100644 tensorflow/compiler/xla/client/BUILD create mode 100644 tensorflow/compiler/xla/client/client.cc create mode 100644 tensorflow/compiler/xla/client/client.h create mode 100644 tensorflow/compiler/xla/client/client_library.cc create mode 100644 tensorflow/compiler/xla/client/client_library.h create mode 100644 tensorflow/compiler/xla/client/computation.cc create mode 100644 tensorflow/compiler/xla/client/computation.h create mode 100644 tensorflow/compiler/xla/client/computation_builder.cc create mode 100644 tensorflow/compiler/xla/client/computation_builder.h create mode 100644 tensorflow/compiler/xla/client/global_data.cc create mode 100644 tensorflow/compiler/xla/client/global_data.h create mode 100644 tensorflow/compiler/xla/client/lib/BUILD create mode 100644 tensorflow/compiler/xla/client/lib/arithmetic.cc create mode 100644 tensorflow/compiler/xla/client/lib/arithmetic.h create mode 100644 tensorflow/compiler/xla/client/lib/testing.cc create mode 100644 tensorflow/compiler/xla/client/lib/testing.h create mode 100644 tensorflow/compiler/xla/client/local_client.cc create mode 100644 tensorflow/compiler/xla/client/local_client.h create mode 100644 tensorflow/compiler/xla/client/padding.cc create mode 100644 tensorflow/compiler/xla/client/padding.h create mode 100644 tensorflow/compiler/xla/client/padding_test.cc create mode 100644 tensorflow/compiler/xla/device_util.h create mode 100644 tensorflow/compiler/xla/differential_set.h create mode 100644 tensorflow/compiler/xla/differential_set_test.cc create mode 100644 tensorflow/compiler/xla/executable_run_options.cc create mode 100644 tensorflow/compiler/xla/executable_run_options.h create mode 100644 tensorflow/compiler/xla/index_util.cc create mode 100644 tensorflow/compiler/xla/index_util.h create mode 100644 tensorflow/compiler/xla/index_util_test.cc create mode 100644 tensorflow/compiler/xla/layout_util.cc create mode 100644 
tensorflow/compiler/xla/layout_util.h create mode 100644 tensorflow/compiler/xla/layout_util_test.cc create mode 100644 tensorflow/compiler/xla/legacy_flags/BUILD create mode 100644 tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.cc create mode 100644 tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.h create mode 100644 tensorflow/compiler/xla/legacy_flags/backend_flags.cc create mode 100644 tensorflow/compiler/xla/legacy_flags/backend_flags.h create mode 100644 tensorflow/compiler/xla/legacy_flags/buffer_assignment_flags.cc create mode 100644 tensorflow/compiler/xla/legacy_flags/buffer_assignment_flags.h create mode 100644 tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.cc create mode 100644 tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.h create mode 100644 tensorflow/compiler/xla/legacy_flags/convolution_thunk_flags.cc create mode 100644 tensorflow/compiler/xla/legacy_flags/convolution_thunk_flags.h create mode 100644 tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.cc create mode 100644 tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h create mode 100644 tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.cc create mode 100644 tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h create mode 100644 tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.cc create mode 100644 tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.h create mode 100644 tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.cc create mode 100644 tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.h create mode 100644 tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.cc create mode 100644 tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h create mode 100644 tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.cc create mode 100644 tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.h create mode 100644 tensorflow/compiler/xla/legacy_flags/hlo_test_base_flags.cc create mode 100644 tensorflow/compiler/xla/legacy_flags/hlo_test_base_flags.h create mode 100644 tensorflow/compiler/xla/legacy_flags/layout_util_flags.cc create mode 100644 tensorflow/compiler/xla/legacy_flags/layout_util_flags.h create mode 100644 tensorflow/compiler/xla/legacy_flags/llvm_backend_flags.cc create mode 100644 tensorflow/compiler/xla/legacy_flags/llvm_backend_flags.h create mode 100644 tensorflow/compiler/xla/legacy_flags/llvm_util_flags.cc create mode 100644 tensorflow/compiler/xla/legacy_flags/llvm_util_flags.h create mode 100644 tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.cc create mode 100644 tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h create mode 100644 tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc create mode 100644 tensorflow/compiler/xla/legacy_flags/service_flags.cc create mode 100644 tensorflow/compiler/xla/legacy_flags/service_flags.h create mode 100644 tensorflow/compiler/xla/legacy_flags/stream_assignment_flags.cc create mode 100644 tensorflow/compiler/xla/legacy_flags/stream_assignment_flags.h create mode 100644 tensorflow/compiler/xla/legacy_flags/util_flags.cc create mode 100644 tensorflow/compiler/xla/legacy_flags/util_flags.h create mode 100644 tensorflow/compiler/xla/literal_util.cc create mode 100644 tensorflow/compiler/xla/literal_util.h create mode 100644 tensorflow/compiler/xla/literal_util_test.cc create mode 100644 tensorflow/compiler/xla/map_util.h create mode 100644 tensorflow/compiler/xla/packed_literal_reader.cc create mode 100644 
tensorflow/compiler/xla/packed_literal_reader.h create mode 100644 tensorflow/compiler/xla/port/BUILD create mode 100644 tensorflow/compiler/xla/port/initialize.h create mode 100644 tensorflow/compiler/xla/primitive_util.cc create mode 100644 tensorflow/compiler/xla/primitive_util.h create mode 100644 tensorflow/compiler/xla/protobuf_util.cc create mode 100644 tensorflow/compiler/xla/protobuf_util.h create mode 100644 tensorflow/compiler/xla/ptr_util.h create mode 100644 tensorflow/compiler/xla/reference_util.cc create mode 100644 tensorflow/compiler/xla/reference_util.h create mode 100644 tensorflow/compiler/xla/reference_util_test.cc create mode 100644 tensorflow/compiler/xla/service/BUILD create mode 100644 tensorflow/compiler/xla/service/algebraic_simplifier.cc create mode 100644 tensorflow/compiler/xla/service/algebraic_simplifier.h create mode 100644 tensorflow/compiler/xla/service/algebraic_simplifier_test.cc create mode 100644 tensorflow/compiler/xla/service/allocation_tracker.cc create mode 100644 tensorflow/compiler/xla/service/allocation_tracker.h create mode 100644 tensorflow/compiler/xla/service/backend.cc create mode 100644 tensorflow/compiler/xla/service/backend.h create mode 100644 tensorflow/compiler/xla/service/buffer_assignment.cc create mode 100644 tensorflow/compiler/xla/service/buffer_assignment.h create mode 100644 tensorflow/compiler/xla/service/buffer_assignment_test.cc create mode 100644 tensorflow/compiler/xla/service/buffer_liveness.cc create mode 100644 tensorflow/compiler/xla/service/buffer_liveness.h create mode 100644 tensorflow/compiler/xla/service/buffer_liveness_test.cc create mode 100644 tensorflow/compiler/xla/service/channel_tracker.cc create mode 100644 tensorflow/compiler/xla/service/channel_tracker.h create mode 100644 tensorflow/compiler/xla/service/compilation_cache.cc create mode 100644 tensorflow/compiler/xla/service/compilation_cache.h create mode 100644 tensorflow/compiler/xla/service/compiler.cc create mode 100644 tensorflow/compiler/xla/service/compiler.h create mode 100644 tensorflow/compiler/xla/service/computation_layout.cc create mode 100644 tensorflow/compiler/xla/service/computation_layout.h create mode 100644 tensorflow/compiler/xla/service/computation_tracker.cc create mode 100644 tensorflow/compiler/xla/service/computation_tracker.h create mode 100644 tensorflow/compiler/xla/service/copy_insertion.cc create mode 100644 tensorflow/compiler/xla/service/copy_insertion.h create mode 100644 tensorflow/compiler/xla/service/copy_insertion_test.cc create mode 100644 tensorflow/compiler/xla/service/cpu/BUILD create mode 100644 tensorflow/compiler/xla/service/cpu/build_defs.bzl create mode 100644 tensorflow/compiler/xla/service/cpu/compiler_functor.cc create mode 100644 tensorflow/compiler/xla/service/cpu/compiler_functor.h create mode 100644 tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc create mode 100644 tensorflow/compiler/xla/service/cpu/conv_canonicalization.h create mode 100644 tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc create mode 100644 tensorflow/compiler/xla/service/cpu/cpu_compiler.cc create mode 100644 tensorflow/compiler/xla/service/cpu/cpu_compiler.h create mode 100644 tensorflow/compiler/xla/service/cpu/cpu_executable.cc create mode 100644 tensorflow/compiler/xla/service/cpu/cpu_executable.h create mode 100644 tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc create mode 100644 tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h create mode 100644 
tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc create mode 100644 tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h create mode 100644 tensorflow/compiler/xla/service/cpu/cpu_runtime.cc create mode 100644 tensorflow/compiler/xla/service/cpu/cpu_runtime.h create mode 100644 tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.cc create mode 100644 tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h create mode 100644 tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.cc create mode 100644 tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h create mode 100644 tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc create mode 100644 tensorflow/compiler/xla/service/cpu/disassembler.cc create mode 100644 tensorflow/compiler/xla/service/cpu/disassembler.h create mode 100644 tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc create mode 100644 tensorflow/compiler/xla/service/cpu/dot_op_emitter.h create mode 100644 tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc create mode 100644 tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h create mode 100644 tensorflow/compiler/xla/service/cpu/infeed_manager.cc create mode 100644 tensorflow/compiler/xla/service/cpu/infeed_manager.h create mode 100644 tensorflow/compiler/xla/service/cpu/infeed_manager_test.cc create mode 100644 tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc create mode 100644 tensorflow/compiler/xla/service/cpu/ir_emission_utils.h create mode 100644 tensorflow/compiler/xla/service/cpu/ir_emitter.cc create mode 100644 tensorflow/compiler/xla/service/cpu/ir_emitter.h create mode 100644 tensorflow/compiler/xla/service/cpu/layout_assignment.cc create mode 100644 tensorflow/compiler/xla/service/cpu/layout_assignment.h create mode 100644 tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc create mode 100644 tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h create mode 100644 tensorflow/compiler/xla/service/cpu/runtime_conv2d.cc create mode 100644 tensorflow/compiler/xla/service/cpu/runtime_conv2d.h create mode 100644 tensorflow/compiler/xla/service/cpu/runtime_conv2d_impl.h create mode 100644 tensorflow/compiler/xla/service/cpu/runtime_matmul.cc create mode 100644 tensorflow/compiler/xla/service/cpu/runtime_matmul.h create mode 100644 tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.cc create mode 100644 tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h create mode 100644 tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc create mode 100644 tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h create mode 100644 tensorflow/compiler/xla/service/cpu/sample_harness.cc create mode 100644 tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc create mode 100644 tensorflow/compiler/xla/service/cpu/simple_orc_jit.h create mode 100644 tensorflow/compiler/xla/service/cpu_transfer_manager.cc create mode 100644 tensorflow/compiler/xla/service/cpu_transfer_manager.h create mode 100644 tensorflow/compiler/xla/service/device_memory_allocator.cc create mode 100644 tensorflow/compiler/xla/service/device_memory_allocator.h create mode 100644 tensorflow/compiler/xla/service/dfs_hlo_visitor.cc create mode 100644 tensorflow/compiler/xla/service/dfs_hlo_visitor.h create mode 100644 tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h create mode 100644 tensorflow/compiler/xla/service/elemental_ir_emitter.cc create mode 100644 tensorflow/compiler/xla/service/elemental_ir_emitter.h 
create mode 100644 tensorflow/compiler/xla/service/executable.cc create mode 100644 tensorflow/compiler/xla/service/executable.h create mode 100644 tensorflow/compiler/xla/service/execution_tracker.cc create mode 100644 tensorflow/compiler/xla/service/execution_tracker.h create mode 100644 tensorflow/compiler/xla/service/generic_transfer_manager.cc create mode 100644 tensorflow/compiler/xla/service/generic_transfer_manager.h create mode 100644 tensorflow/compiler/xla/service/gpu/BUILD create mode 100644 tensorflow/compiler/xla/service/gpu/buffer_allocations.cc create mode 100644 tensorflow/compiler/xla/service/gpu/buffer_allocations.h create mode 100644 tensorflow/compiler/xla/service/gpu/convolution_folding.cc create mode 100644 tensorflow/compiler/xla/service/gpu/convolution_folding.h create mode 100644 tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc create mode 100644 tensorflow/compiler/xla/service/gpu/convolution_thunk.cc create mode 100644 tensorflow/compiler/xla/service/gpu/convolution_thunk.h create mode 100644 tensorflow/compiler/xla/service/gpu/copy_insertion.cc create mode 100644 tensorflow/compiler/xla/service/gpu/copy_insertion.h create mode 100644 tensorflow/compiler/xla/service/gpu/copy_thunk.cc create mode 100644 tensorflow/compiler/xla/service/gpu/copy_thunk.h create mode 100644 tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc create mode 100644 tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h create mode 100644 tensorflow/compiler/xla/service/gpu/for_thunk.cc create mode 100644 tensorflow/compiler/xla/service/gpu/for_thunk.h create mode 100644 tensorflow/compiler/xla/service/gpu/gemm_thunk.cc create mode 100644 tensorflow/compiler/xla/service/gpu/gemm_thunk.h create mode 100644 tensorflow/compiler/xla/service/gpu/gpu_compiler.cc create mode 100644 tensorflow/compiler/xla/service/gpu/gpu_compiler.h create mode 100644 tensorflow/compiler/xla/service/gpu/gpu_executable.cc create mode 100644 tensorflow/compiler/xla/service/gpu/gpu_executable.h create mode 100644 tensorflow/compiler/xla/service/gpu/hlo_schedule.cc create mode 100644 tensorflow/compiler/xla/service/gpu/hlo_schedule.h create mode 100644 tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc create mode 100644 tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc create mode 100644 tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h create mode 100644 tensorflow/compiler/xla/service/gpu/instruction_fusion.cc create mode 100644 tensorflow/compiler/xla/service/gpu/instruction_fusion.h create mode 100644 tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc create mode 100644 tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc create mode 100644 tensorflow/compiler/xla/service/gpu/ir_emission_utils.h create mode 100644 tensorflow/compiler/xla/service/gpu/ir_emitter.cc create mode 100644 tensorflow/compiler/xla/service/gpu/ir_emitter.h create mode 100644 tensorflow/compiler/xla/service/gpu/ir_emitter_context.h create mode 100644 tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc create mode 100644 tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc create mode 100644 tensorflow/compiler/xla/service/gpu/kernel_thunk.cc create mode 100644 tensorflow/compiler/xla/service/gpu/kernel_thunk.h create mode 100644 tensorflow/compiler/xla/service/gpu/layout_assignment.cc create mode 100644 tensorflow/compiler/xla/service/gpu/layout_assignment.h create mode 100644 tensorflow/compiler/xla/service/gpu/layout_assignment_test.cc create mode 100644 
tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD create mode 100644 tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.cc create mode 100644 tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.h create mode 100644 tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc create mode 100644 tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h create mode 100644 tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/tests_data/saxpy.ll create mode 100644 tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.cc create mode 100644 tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h create mode 100644 tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils_test.cc create mode 100644 tensorflow/compiler/xla/service/gpu/pad_insertion.cc create mode 100644 tensorflow/compiler/xla/service/gpu/pad_insertion.h create mode 100644 tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc create mode 100644 tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h create mode 100644 tensorflow/compiler/xla/service/gpu/partition_assignment.cc create mode 100644 tensorflow/compiler/xla/service/gpu/partition_assignment.h create mode 100644 tensorflow/compiler/xla/service/gpu/sequential_thunk.cc create mode 100644 tensorflow/compiler/xla/service/gpu/sequential_thunk.h create mode 100644 tensorflow/compiler/xla/service/gpu/stream_assignment.cc create mode 100644 tensorflow/compiler/xla/service/gpu/stream_assignment.h create mode 100644 tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc create mode 100644 tensorflow/compiler/xla/service/gpu/temp_buffer_offsets.cc create mode 100644 tensorflow/compiler/xla/service/gpu/temp_buffer_offsets.h create mode 100644 tensorflow/compiler/xla/service/gpu/thunk.h create mode 100644 tensorflow/compiler/xla/service/gpu/thunk_schedule.cc create mode 100644 tensorflow/compiler/xla/service/gpu/thunk_schedule.h create mode 100644 tensorflow/compiler/xla/service/gpu/tuple_thunk.cc create mode 100644 tensorflow/compiler/xla/service/gpu/tuple_thunk.h create mode 100644 tensorflow/compiler/xla/service/gpu/while_thunk.cc create mode 100644 tensorflow/compiler/xla/service/gpu/while_thunk.h create mode 100644 tensorflow/compiler/xla/service/gpu/while_transformer.cc create mode 100644 tensorflow/compiler/xla/service/gpu/while_transformer.h create mode 100644 tensorflow/compiler/xla/service/gpu/while_transformer_test.cc create mode 100644 tensorflow/compiler/xla/service/graphviz_example.cc create mode 100644 tensorflow/compiler/xla/service/hlo_computation.cc create mode 100644 tensorflow/compiler/xla/service/hlo_computation.h create mode 100644 tensorflow/compiler/xla/service/hlo_computation_test.cc create mode 100644 tensorflow/compiler/xla/service/hlo_cost_analysis.cc create mode 100644 tensorflow/compiler/xla/service/hlo_cost_analysis.h create mode 100644 tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc create mode 100644 tensorflow/compiler/xla/service/hlo_cse.cc create mode 100644 tensorflow/compiler/xla/service/hlo_cse.h create mode 100644 tensorflow/compiler/xla/service/hlo_cse_test.cc create mode 100644 tensorflow/compiler/xla/service/hlo_dce.cc create mode 100644 tensorflow/compiler/xla/service/hlo_dce.h create mode 100644 tensorflow/compiler/xla/service/hlo_dce_test.cc create mode 100644 tensorflow/compiler/xla/service/hlo_execution_profile.cc create mode 100644 tensorflow/compiler/xla/service/hlo_execution_profile.h create mode 100644 
tensorflow/compiler/xla/service/hlo_graph_dumper.cc create mode 100644 tensorflow/compiler/xla/service/hlo_graph_dumper.h create mode 100644 tensorflow/compiler/xla/service/hlo_instruction.cc create mode 100644 tensorflow/compiler/xla/service/hlo_instruction.h create mode 100644 tensorflow/compiler/xla/service/hlo_instruction_test.cc create mode 100644 tensorflow/compiler/xla/service/hlo_module.cc create mode 100644 tensorflow/compiler/xla/service/hlo_module.h create mode 100644 tensorflow/compiler/xla/service/hlo_module_config.cc create mode 100644 tensorflow/compiler/xla/service/hlo_module_config.h create mode 100644 tensorflow/compiler/xla/service/hlo_module_test.cc create mode 100644 tensorflow/compiler/xla/service/hlo_opcode.cc create mode 100644 tensorflow/compiler/xla/service/hlo_opcode.h create mode 100644 tensorflow/compiler/xla/service/hlo_opcode_test.cc create mode 100644 tensorflow/compiler/xla/service/hlo_pass.h create mode 100644 tensorflow/compiler/xla/service/hlo_pass_pipeline.cc create mode 100644 tensorflow/compiler/xla/service/hlo_pass_pipeline.h create mode 100644 tensorflow/compiler/xla/service/hlo_query.cc create mode 100644 tensorflow/compiler/xla/service/hlo_query.h create mode 100644 tensorflow/compiler/xla/service/hlo_subcomputation_unification.cc create mode 100644 tensorflow/compiler/xla/service/hlo_subcomputation_unification.h create mode 100644 tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc create mode 100644 tensorflow/compiler/xla/service/inliner.cc create mode 100644 tensorflow/compiler/xla/service/inliner.h create mode 100644 tensorflow/compiler/xla/service/inliner_test.cc create mode 100644 tensorflow/compiler/xla/service/instruction_fusion.cc create mode 100644 tensorflow/compiler/xla/service/instruction_fusion.h create mode 100644 tensorflow/compiler/xla/service/instruction_fusion_test.cc create mode 100644 tensorflow/compiler/xla/service/layout_assignment.cc create mode 100644 tensorflow/compiler/xla/service/layout_assignment.h create mode 100644 tensorflow/compiler/xla/service/layout_assignment_test.cc create mode 100644 tensorflow/compiler/xla/service/llvm_ir/BUILD create mode 100644 tensorflow/compiler/xla/service/llvm_ir/README.md create mode 100644 tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc create mode 100644 tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h create mode 100644 tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc create mode 100644 tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h create mode 100644 tensorflow/compiler/xla/service/llvm_ir/ir_array.cc create mode 100644 tensorflow/compiler/xla/service/llvm_ir/ir_array.h create mode 100644 tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc create mode 100644 tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h create mode 100644 tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc create mode 100644 tensorflow/compiler/xla/service/llvm_ir/llvm_util.h create mode 100644 tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc create mode 100644 tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h create mode 100644 tensorflow/compiler/xla/service/llvm_ir/ops.cc create mode 100644 tensorflow/compiler/xla/service/llvm_ir/ops.h create mode 100644 tensorflow/compiler/xla/service/local_service.cc create mode 100644 tensorflow/compiler/xla/service/local_service.h create mode 100644 tensorflow/compiler/xla/service/logical_buffer.cc create mode 100644 tensorflow/compiler/xla/service/logical_buffer.h create mode 100644 
tensorflow/compiler/xla/service/name_uniquer.cc create mode 100644 tensorflow/compiler/xla/service/name_uniquer.h create mode 100644 tensorflow/compiler/xla/service/platform_util.cc create mode 100644 tensorflow/compiler/xla/service/platform_util.h create mode 100644 tensorflow/compiler/xla/service/reshape_mover.cc create mode 100644 tensorflow/compiler/xla/service/reshape_mover.h create mode 100644 tensorflow/compiler/xla/service/reshape_mover_test.cc create mode 100644 tensorflow/compiler/xla/service/service.cc create mode 100644 tensorflow/compiler/xla/service/service.h create mode 100644 tensorflow/compiler/xla/service/session.proto create mode 100644 tensorflow/compiler/xla/service/shape_inference.cc create mode 100644 tensorflow/compiler/xla/service/shape_inference.h create mode 100644 tensorflow/compiler/xla/service/shape_inference_test.cc create mode 100644 tensorflow/compiler/xla/service/shaped_buffer.cc create mode 100644 tensorflow/compiler/xla/service/shaped_buffer.h create mode 100644 tensorflow/compiler/xla/service/transfer_manager.cc create mode 100644 tensorflow/compiler/xla/service/transfer_manager.h create mode 100644 tensorflow/compiler/xla/service/transfer_manager_test.cc create mode 100644 tensorflow/compiler/xla/service/transpose_folding.cc create mode 100644 tensorflow/compiler/xla/service/transpose_folding.h create mode 100644 tensorflow/compiler/xla/service/transpose_folding_test.cc create mode 100644 tensorflow/compiler/xla/service/tuple_points_to_analysis.cc create mode 100644 tensorflow/compiler/xla/service/tuple_points_to_analysis.h create mode 100644 tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc create mode 100644 tensorflow/compiler/xla/service/user_computation.cc create mode 100644 tensorflow/compiler/xla/service/user_computation.h create mode 100644 tensorflow/compiler/xla/service/versioned_computation_handle.h create mode 100644 tensorflow/compiler/xla/service_interface.h create mode 100644 tensorflow/compiler/xla/shape_layout.cc create mode 100644 tensorflow/compiler/xla/shape_layout.h create mode 100644 tensorflow/compiler/xla/shape_tree.h create mode 100644 tensorflow/compiler/xla/shape_tree_test.cc create mode 100644 tensorflow/compiler/xla/shape_util.cc create mode 100644 tensorflow/compiler/xla/shape_util.h create mode 100644 tensorflow/compiler/xla/shape_util_test.cc create mode 100644 tensorflow/compiler/xla/status.h create mode 100644 tensorflow/compiler/xla/status_macros.cc create mode 100644 tensorflow/compiler/xla/status_macros.h create mode 100644 tensorflow/compiler/xla/status_macros_test.cc create mode 100644 tensorflow/compiler/xla/statusor.cc create mode 100644 tensorflow/compiler/xla/statusor.h create mode 100644 tensorflow/compiler/xla/statusor_test.cc create mode 100644 tensorflow/compiler/xla/test_helpers.cc create mode 100644 tensorflow/compiler/xla/test_helpers.h create mode 100644 tensorflow/compiler/xla/tests/BUILD create mode 100644 tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc create mode 100644 tensorflow/compiler/xla/tests/axpy_simple_test.cc create mode 100644 tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc create mode 100644 tensorflow/compiler/xla/tests/batch_normalization_test.cc create mode 100644 tensorflow/compiler/xla/tests/binop_scaling_test.cc create mode 100644 tensorflow/compiler/xla/tests/broadcast_simple_test.cc create mode 100644 tensorflow/compiler/xla/tests/broadcast_test.cc create mode 100644 tensorflow/compiler/xla/tests/build_defs.bzl create mode 100644 
tensorflow/compiler/xla/tests/call_test.cc create mode 100644 tensorflow/compiler/xla/tests/check_execution_arity_test.cc create mode 100644 tensorflow/compiler/xla/tests/client_library_test_base.cc create mode 100644 tensorflow/compiler/xla/tests/client_library_test_base.h create mode 100644 tensorflow/compiler/xla/tests/client_test.cc create mode 100644 tensorflow/compiler/xla/tests/codegen_test_base.cc create mode 100644 tensorflow/compiler/xla/tests/codegen_test_base.h create mode 100644 tensorflow/compiler/xla/tests/compilation_cache_test.cc create mode 100644 tensorflow/compiler/xla/tests/compute_constant_test.cc create mode 100644 tensorflow/compiler/xla/tests/concat_test.cc create mode 100644 tensorflow/compiler/xla/tests/constants_test.cc create mode 100644 tensorflow/compiler/xla/tests/convert_test.cc create mode 100644 tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc create mode 100644 tensorflow/compiler/xla/tests/convolution_test.cc create mode 100644 tensorflow/compiler/xla/tests/convolution_variants_test.cc create mode 100644 tensorflow/compiler/xla/tests/copy_test.cc create mode 100644 tensorflow/compiler/xla/tests/custom_call_test.cc create mode 100644 tensorflow/compiler/xla/tests/deallocation_test.cc create mode 100644 tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc create mode 100644 tensorflow/compiler/xla/tests/dot_operation_test.cc create mode 100644 tensorflow/compiler/xla/tests/dynamic_ops_test.cc create mode 100644 tensorflow/compiler/xla/tests/floor_ceil_test.cc create mode 100644 tensorflow/compiler/xla/tests/fmax_test.cc create mode 100644 tensorflow/compiler/xla/tests/fusion_test.cc create mode 100644 tensorflow/compiler/xla/tests/hlo_test_base.cc create mode 100644 tensorflow/compiler/xla/tests/hlo_test_base.h create mode 100644 tensorflow/compiler/xla/tests/inprocess_service_test.cc create mode 100644 tensorflow/compiler/xla/tests/literal_test_util.cc create mode 100644 tensorflow/compiler/xla/tests/literal_test_util.h create mode 100644 tensorflow/compiler/xla/tests/literal_test_util_test.cc create mode 100644 tensorflow/compiler/xla/tests/local_client_aot_test.cc create mode 100644 tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc create mode 100644 tensorflow/compiler/xla/tests/local_client_test_base.cc create mode 100644 tensorflow/compiler/xla/tests/local_client_test_base.h create mode 100644 tensorflow/compiler/xla/tests/log_test.cc create mode 100644 tensorflow/compiler/xla/tests/map_test.cc create mode 100644 tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc create mode 100644 tensorflow/compiler/xla/tests/multidimensional_slice_test.cc create mode 100644 tensorflow/compiler/xla/tests/pad_test.cc create mode 100644 tensorflow/compiler/xla/tests/params_test.cc create mode 100644 tensorflow/compiler/xla/tests/pred_test.cc create mode 100644 tensorflow/compiler/xla/tests/prng_test.cc create mode 100644 tensorflow/compiler/xla/tests/query_inferred_shape_test.cc create mode 100644 tensorflow/compiler/xla/tests/reduce_test.cc create mode 100644 tensorflow/compiler/xla/tests/reduce_window_test.cc create mode 100644 tensorflow/compiler/xla/tests/replay_test.cc create mode 100644 tensorflow/compiler/xla/tests/reshape_motion_test.cc create mode 100644 tensorflow/compiler/xla/tests/reshape_test.cc create mode 100644 tensorflow/compiler/xla/tests/reverse_test.cc create mode 100644 tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc create mode 100644 
tensorflow/compiler/xla/tests/round_trip_transfer_test.cc create mode 100644 tensorflow/compiler/xla/tests/scalar_computations_test.cc create mode 100644 tensorflow/compiler/xla/tests/select_and_scatter_test.cc create mode 100644 tensorflow/compiler/xla/tests/select_test.cc create mode 100644 tensorflow/compiler/xla/tests/set_return_value_test.cc create mode 100644 tensorflow/compiler/xla/tests/slice_test.cc create mode 100644 tensorflow/compiler/xla/tests/test_macros.h create mode 100644 tensorflow/compiler/xla/tests/test_utils.h create mode 100644 tensorflow/compiler/xla/tests/transpose_test.cc create mode 100644 tensorflow/compiler/xla/tests/tuple_test.cc create mode 100644 tensorflow/compiler/xla/tests/unary_op_test.cc create mode 100644 tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc create mode 100644 tensorflow/compiler/xla/tests/vector_ops_simple_test.cc create mode 100644 tensorflow/compiler/xla/tests/while_test.cc create mode 100644 tensorflow/compiler/xla/text_literal_reader.cc create mode 100644 tensorflow/compiler/xla/text_literal_reader.h create mode 100644 tensorflow/compiler/xla/text_literal_reader_test.cc create mode 100644 tensorflow/compiler/xla/text_literal_writer.cc create mode 100644 tensorflow/compiler/xla/text_literal_writer.h create mode 100644 tensorflow/compiler/xla/text_literal_writer_test.cc create mode 100644 tensorflow/compiler/xla/tools/BUILD create mode 100644 tensorflow/compiler/xla/tools/convert_computation.cc create mode 100644 tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc create mode 100644 tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc create mode 100644 tensorflow/compiler/xla/tools/dumped_computation_to_text.cc create mode 100644 tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc create mode 100644 tensorflow/compiler/xla/tools/replay_computation.cc create mode 100644 tensorflow/compiler/xla/tools/show_literal.cc create mode 100644 tensorflow/compiler/xla/tools/show_signature.cc create mode 100644 tensorflow/compiler/xla/tools/show_text_literal.cc create mode 100644 tensorflow/compiler/xla/types.h create mode 100644 tensorflow/compiler/xla/util.cc create mode 100644 tensorflow/compiler/xla/util.h create mode 100644 tensorflow/compiler/xla/util_test.cc create mode 100644 tensorflow/compiler/xla/window_util.cc create mode 100644 tensorflow/compiler/xla/window_util.h create mode 100644 tensorflow/compiler/xla/xla.bzl create mode 100644 tensorflow/compiler/xla/xla.proto create mode 100644 tensorflow/compiler/xla/xla_data.proto diff --git a/configure b/configure index 8d4b12aad2..64add33bd5 100755 --- a/configure +++ b/configure @@ -112,6 +112,26 @@ else sed -i -e "s/WITH_HDFS_SUPPORT = True/WITH_HDFS_SUPPORT = False/" tensorflow/core/platform/default/build_config.bzl fi +## Enable XLA. +while [ "$TF_ENABLE_XLA" == "" ]; do + read -p "Do you wish to build TensorFlow with the XLA just-in-time compiler (experimental)? [y/N] " INPUT + case $INPUT in + [Yy]* ) echo "XLA JIT support will be enabled for TensorFlow"; TF_ENABLE_XLA=1;; + [Nn]* ) echo "No XLA JIT support will be enabled for TensorFlow"; TF_ENABLE_XLA=0;; + "" ) echo "No XLA support will be enabled for TensorFlow"; TF_ENABLE_XLA=0;; + * ) echo "Invalid selection: " $INPUT;; + esac +done + +if [ "$TF_ENABLE_XLA" == "1" ]; then + # Update Bazel build configuration. + perl -pi -e "s,WITH_XLA_SUPPORT = (False|True),WITH_XLA_SUPPORT = True,s" tensorflow/core/platform/default/build_config.bzl +else + # Update Bazel build configuration. 
+ perl -pi -e "s,WITH_XLA_SUPPORT = (False|True),WITH_XLA_SUPPORT = False,s" tensorflow/core/platform/default/build_config.bzl +fi + + # Invoke python_config and set up symlinks to python includes ./util/python/python_config.sh --setup "$PYTHON_BIN_PATH" diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 44015ff7d9..ef04a6a88e 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -95,6 +95,26 @@ filegroup( "//tensorflow/c:all_files", "//tensorflow/cc:all_files", "//tensorflow/cc/saved_model:all_files", + "//tensorflow/compiler/aot:all_files", + "//tensorflow/compiler/aot/tests:all_files", + "//tensorflow/compiler/jit:all_files", + "//tensorflow/compiler/jit/graphcycles:all_files", + "//tensorflow/compiler/jit/legacy_flags:all_files", + "//tensorflow/compiler/tests:all_files", + "//tensorflow/compiler/tf2xla:all_files", + "//tensorflow/compiler/tf2xla/kernels:all_files", + "//tensorflow/compiler/xla:all_files", + "//tensorflow/compiler/xla/client:all_files", + "//tensorflow/compiler/xla/client/lib:all_files", + "//tensorflow/compiler/xla/legacy_flags:all_files", + "//tensorflow/compiler/xla/port:all_files", + "//tensorflow/compiler/xla/service:all_files", + "//tensorflow/compiler/xla/service/cpu:all_files", + "//tensorflow/compiler/xla/service/gpu:all_files", + "//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend:all_files", + "//tensorflow/compiler/xla/service/llvm_ir:all_files", + "//tensorflow/compiler/xla/tests:all_files", + "//tensorflow/compiler/xla/tools:all_files", "//tensorflow/contrib:all_files", "//tensorflow/contrib/android:all_files", "//tensorflow/contrib/bayesflow:all_files", diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD new file mode 100644 index 0000000000..c52a56b642 --- /dev/null +++ b/tensorflow/compiler/aot/BUILD @@ -0,0 +1,218 @@ +licenses(["notice"]) # Apache 2.0 + +package( + default_visibility = ["//visibility:private"], +) + +load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") +load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") + +# Optional runtime utilities for use by code generated by tfcompile. +cc_library( + name = "runtime", + srcs = ["runtime.cc"], + hdrs = ["runtime.h"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:framework_lite", + ], +) + +cc_test( + name = "runtime_test", + srcs = ["runtime_test.cc"], + deps = [ + ":runtime", + "//tensorflow/compiler/tf2xla:xla_local_runtime_context", + "//tensorflow/core:framework", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +# Don't depend on this directly; this is only used for the benchmark test +# generated by tf_library. 
+cc_library( + name = "tf_library_test_main", + testonly = 1, + visibility = ["//visibility:public"], + deps = ["//tensorflow/core:test_main"], +) + +xla_proto_library( + name = "tfcompile_proto", + srcs = ["tfcompile.proto"], + deps = [ + "//tensorflow/core:protos_all_cc", + ], +) + +cc_library( + name = "tfcompile_lib", + srcs = [ + "codegen.cc", + "compile.cc", + "flags.cc", + "tfcompile_util.cc", + ], + hdrs = [ + "codegen.h", + "compile.h", + "flags.h", + "tfcompile_util.h", + ], + deps = [ + ":runtime", # needed by codegen to print aligned_buffer_bytes + ":tfcompile_proto", + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops", + "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/service:compiler", + "//tensorflow/compiler/xla/service/cpu:cpu_compiler", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:stream_executor_no_cuda", + ], +) + +cc_test( + name = "codegen_test", + srcs = ["codegen_test.cc"], + data = ["codegen_test_h.golden"], + deps = [ + ":tfcompile_lib", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_test( + name = "tfcompile_util_test", + srcs = ["tfcompile_util_test.cc"], + deps = [ + ":tfcompile_lib", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_binary( + name = "tfcompile", + visibility = ["//visibility:public"], + deps = [":tfcompile_main"], +) + +cc_library( + name = "tfcompile_main", + srcs = ["tfcompile_main.cc"], + visibility = ["//visibility:public"], + deps = [ + ":tfcompile_lib", + ":tfcompile_proto", + "//tensorflow/compiler/xla/legacy_flags:compiler_functor_flags", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags", + "//tensorflow/compiler/xla/service:compiler", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], +) + +# NOTE: Most end-to-end tests are in the "tests" subdirectory, to ensure that +# tfcompile.bzl correctly handles usage from outside of the package that it is +# defined in. + +# A simple test of tf_library from a text protobuf, mostly to enable the +# benchmark_test. +tf_library( + name = "test_graph_tfadd", + testonly = 1, + config = "test_graph_tfadd.config.pbtxt", + cpp_class = "AddComp", + graph = "test_graph_tfadd.pbtxt", + tags = ["manual"], +) + +# Utility library for benchmark binaries, used by the *_benchmark rules that are +# added by the tfcompile bazel macro. +cc_library( + name = "benchmark", + srcs = ["benchmark.cc"], + hdrs = ["benchmark.h"], + visibility = ["//visibility:public"], + deps = [ + # The purpose of the benchmark library is to support building an aot + # binary with minimal dependencies, to demonstrate small binary sizes. + # + # KEEP THE DEPENDENCIES MINIMAL. 
+ "//tensorflow/core:framework_lite", + ], +) + +cc_library( + name = "benchmark_extra_android", + tags = [ + "manual", + "notap", + ], + visibility = ["//visibility:public"], +) + +cc_test( + name = "benchmark_test", + srcs = ["benchmark_test.cc"], + tags = ["manual"], + deps = [ + ":benchmark", + ":test_graph_tfadd", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +test_suite( + name = "all_tests", + tags = ["manual"], + tests = [ + ":benchmark_test", + ":test_graph_tfadd_test", + "//tensorflow/compiler/aot/tests:all_tests", + ], +) + +exports_files([ + "benchmark_main.template", # used by tf_library(...,gen_benchmark=True) + "test.cc", # used by tf_library(...,gen_test=True) +]) + +# ----------------------------------------------------------------------------- + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/compiler/aot/benchmark.cc b/tensorflow/compiler/aot/benchmark.cc new file mode 100644 index 0000000000..0c5e2c103e --- /dev/null +++ b/tensorflow/compiler/aot/benchmark.cc @@ -0,0 +1,138 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// The purpose of the benchmark library is to support building an aot binary +// with minimal dependencies, to demonstrate small binary sizes. +// +// KEEP THE DEPENDENCIES MINIMAL. + +#include "tensorflow/compiler/aot/benchmark.h" + +#include + +#include +#include +#include +#include +#include + +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace tfcompile { +namespace benchmark { + +// Returns current wall time in micros. +// +// TODO(b/33546473): Refactor tensorflow::Env::NowMicros() so that we can re-use +// the implementation without pulling in all of the Env dependencies. +static double NowMicros() { + struct timeval tv; + gettimeofday(&tv, NULL); + return static_cast(tv.tv_sec) * 1000000 + tv.tv_usec; +} + +void DumpStatsToStdout(const Stats& stats) { + // Compute stats. + std::vector sorted_us(stats.per_iter_us); + std::sort(sorted_us.begin(), sorted_us.end()); + const size_t count_us = sorted_us.size(); + double sum_us = 0; + size_t count_us_trimmed = 0; + double sum_us_trimmed = 0; + size_t count_us_best = 0; + double sum_us_best = 0; + static constexpr float trim_ratio = 0.25; + static constexpr float best_ratio = 0.1; + const size_t count_trimmed = count_us * trim_ratio; + const size_t count_best = count_us * best_ratio; + for (size_t i = 0; i < sorted_us.size(); ++i) { + const int64 us = sorted_us[i]; + sum_us += us; + if (i >= count_trimmed && i < count_us - count_trimmed) { + sum_us_trimmed += us; + ++count_us_trimmed; + } + if (i < count_best) { + sum_us_best += us; + ++count_us_best; + } + } + // Prepare nicely-formatted data. 
+ const int kBufSize = 1000; + char buf[kBufSize]; + snprintf(buf, kBufSize, "Mean with %2.0f%% trimmed:", trim_ratio * 100); + const string label_trimmed(buf); + snprintf(buf, kBufSize, "Mean of %2.0f%% best:", best_ratio * 100); + const string label_best(buf); + std::vector> groups = { + {"Best:", sorted_us.front()}, + {"Worst:", sorted_us.back()}, + {"Median:", sorted_us[count_us / 2]}, + {"Mean:", sum_us / count_us}, + {label_trimmed, sum_us_trimmed / count_us_trimmed}, + {label_best, sum_us_best / count_us_best}, + }; + int max_label_size = 0; + double max_us = 0; + for (const auto& g : groups) { + if (g.first.size() > max_label_size) { + max_label_size = g.first.size(); + } + if (g.second > max_us) { + max_us = g.second; + } + } + int max_digits = 1; + while (max_us >= 10.0) { + max_us /= 10.0; + ++max_digits; + } + // Dump stats out. + printf("Benchmark ran %zu iterations over %lld us\n", count_us, + stats.total_us); + for (const auto& g : groups) { + printf(" %-*s %*.3f us\n", max_label_size, g.first.c_str(), max_digits + 4, + g.second); + } +} + +void Benchmark(const Options& options, const BenchmarkFn& fn, Stats* stats) { + // If neither max_seconds or max_iters is set, stop at kDefaultMicros. + const int64 max_us = (options.max_micros <= 0 && options.max_iters <= 0) + ? Options::kDefaultMicros + : options.max_micros; + printf("Running benchmark for %lld us\n", max_us); + const int64 start_us = NowMicros(); + int64 iters = 0; + while (true) { + const int64 iter_start_us = NowMicros(); + fn(); + const int64 end_us = NowMicros(); + // Collect stats and decide whether to stop. + stats->per_iter_us.push_back(end_us - iter_start_us); + const int64 total_us = end_us - start_us; + ++iters; + if ((max_us > 0 && total_us >= max_us) || + (options.max_iters > 0 && iters >= options.max_iters)) { + stats->total_us = total_us; + break; + } + } +} + +} // namespace benchmark +} // namespace tfcompile +} // namespace tensorflow diff --git a/tensorflow/compiler/aot/benchmark.h b/tensorflow/compiler/aot/benchmark.h new file mode 100644 index 0000000000..266b7fefc7 --- /dev/null +++ b/tensorflow/compiler/aot/benchmark.h @@ -0,0 +1,70 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Contains benchmark functions used with the code-generated benchmarks that can +// be used to test a model on android. See also code generation rules in +// tfcompile.bzl. +// +// This is separate from the built-in micro-benchmarks, because we want to: +// 1. show a binary with minimal dependencies, to show a close-to-lower-bound +// binary size. +// 2. compile on Android. +#ifndef TENSORFLOW_COMPILER_AOT_BENCHMARK_H_ +#define TENSORFLOW_COMPILER_AOT_BENCHMARK_H_ + +#include +#include +#include + +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace tfcompile { +namespace benchmark { + +// Options specifies options for benchmarks of functions generated by tfcompile. 
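// Illustrative example of how the limits below interact (values are
// hypothetical): max_iters = 100 with max_micros = 0 runs exactly 100
// iterations; max_micros = 1000000 with max_iters = 0 runs for roughly one
// second; leaving both at 0 falls back to kDefaultMicros (3 seconds).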
+struct Options { + // kDefaultMicros specifies the default time to run the benchmark, and is used + // if neither max_iters nor max_micros is set. + static const int64 kDefaultMicros = 3000000; + + int64 max_iters = 0; // Maximum iterations to run, ignored if <= 0. + int64 max_micros = 0; // Maximum microseconds to run, ignored if <= 0. +}; + +// Stats holds statistics collected during benchmarking. +struct Stats { + std::vector per_iter_us; // Per-iteration deltas in us. + int64 total_us; // Total time in us. + + Stats() : total_us(0) { per_iter_us.reserve(5000); } +}; + +// DumpStatsToStdout printfs to stdout stats in a multi-line human-friendly +// form. +void DumpStatsToStdout(const Stats& stats); + +// BenchmarkFn is the signature of the function generated by tfcompile. +typedef std::function BenchmarkFn; + +// Benchmark runs a benchmark of the function `fn`, collecting stats in `stats`. +// Use `options` to configure benchmarking options. +void Benchmark(const Options& options, const BenchmarkFn& fn, Stats* stats); + +} // namespace benchmark +} // namespace tfcompile +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_AOT_BENCHMARK_H_ diff --git a/tensorflow/compiler/aot/benchmark_main.template b/tensorflow/compiler/aot/benchmark_main.template new file mode 100644 index 0000000000..a4df6ed1cf --- /dev/null +++ b/tensorflow/compiler/aot/benchmark_main.template @@ -0,0 +1,51 @@ +// Generated by the tf_library build rule. DO NOT EDIT! +// +// This file contains the main function and logic for benchmarking code +// generated by tfcompile. All tokens of the form `{{TFCOMPILE_*}}` must be +// rewritten to real values before this file can be compiled. +// +// TFCOMPILE_HEADER : Path to the header file generated by tfcompile. +// TFCOMPILE_CPP_CLASS : Name of the C++ class generated by tfcompile. +// +// The tf_library bazel macro in tfcompile.bzl performs the token rewriting, and +// generates a cc_binary rule for you. + +// These macros must be defined before eigen files are included. +#define EIGEN_USE_THREADS +#define EIGEN_USE_CUSTOM_THREAD_POOL + +// clang-format off +#include "{{TFCOMPILE_HEADER}}" // NOLINT(whitespace/braces) +// clang-format on + +#include "tensorflow/compiler/aot/benchmark.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +// Macros that expand to tokens based on the entry point name. +// clang-format off +#define CPP_CLASS {{TFCOMPILE_CPP_CLASS}} // NOLINT(whitespace/braces) +// clang-format on + +namespace tensorflow { +namespace tfcompile { + +int Main(int argc, char** argv) { + Eigen::ThreadPool pool(1 /* num_threads */); + Eigen::ThreadPoolDevice device(&pool, pool.NumThreads()); + + CPP_CLASS computation; + computation.set_thread_pool(&device); + + benchmark::Options options; + benchmark::Stats stats; + benchmark::Benchmark(options, [&] { computation.Run(); }, &stats); + benchmark::DumpStatsToStdout(stats); + return 0; +} + +} // namespace tfcompile +} // namespace tensorflow + +int main(int argc, char** argv) { + return tensorflow::tfcompile::Main(argc, argv); +} diff --git a/tensorflow/compiler/aot/benchmark_test.cc b/tensorflow/compiler/aot/benchmark_test.cc new file mode 100644 index 0000000000..0568c5b1d5 --- /dev/null +++ b/tensorflow/compiler/aot/benchmark_test.cc @@ -0,0 +1,46 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/aot/benchmark.h" + +#include "tensorflow/compiler/aot/test_graph_tfadd.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace tfcompile { +namespace benchmark { +namespace { + +// There isn't much we can verify in a stable fashion, so we just run the +// benchmark with max_iters, and ensure we end up with that many iter stats. +TEST(Benchmark, Benchmark) { + AddComp add; + + Options options; + options.max_iters = 1; + Stats stats1; + Benchmark(options, [&] { add.Run(); }, &stats1); + EXPECT_EQ(stats1.per_iter_us.size(), 1); + + options.max_iters = 5; + Stats stats5; + Benchmark(options, [&] { add.Run(); }, &stats5); + EXPECT_EQ(stats5.per_iter_us.size(), 5); +} + +} // namespace +} // namespace benchmark +} // namespace tfcompile +} // namespace tensorflow diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc new file mode 100644 index 0000000000..042a72745a --- /dev/null +++ b/tensorflow/compiler/aot/codegen.cc @@ -0,0 +1,579 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/aot/codegen.h" + +#include +#include +#include + +#include "tensorflow/compiler/aot/runtime.h" +#include "tensorflow/compiler/aot/tfcompile_util.h" +#include "tensorflow/compiler/tf2xla/str_util.h" +#include "tensorflow/compiler/xla/service/compiler.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace tensorflow { +namespace tfcompile { + +namespace { + +// Convert an XLA type into a C++ type. 
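// For example (per the switch below), xla::F32 maps to "float" and xla::S64
// maps to "tensorflow::int64"; a type with no C++ equivalent returns an
// Unimplemented error.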
+Status XLATypeToCpp(xla::PrimitiveType type, string* str) { + switch (type) { + case xla::PRED: + *str = "bool"; + break; + case xla::S8: + *str = "tensorflow::int8"; + break; + case xla::S16: + *str = "tensorflow::int16"; + break; + case xla::S32: + *str = "tensorflow::int32"; + break; + case xla::S64: + *str = "tensorflow::int64"; + break; + case xla::U8: + *str = "tensorflow::uint8"; + break; + case xla::U16: + *str = "tensorflow::uint16"; + break; + case xla::U32: + *str = "tensorflow::uint32"; + break; + case xla::U64: + *str = "tensorflow::uint64"; + break; + case xla::F32: + *str = "float"; + break; + case xla::F64: + *str = "double"; + break; + default: + return errors::Unimplemented("XLA type ", xla::PrimitiveType_Name(type), + " has no equivalent in C++"); + } + return Status::OK(); +} + +// total_buffer_bytes returns the sum of each size in `sizes`, skipping -1 +// values. There are `n` entries in `sizes`. +size_t total_buffer_bytes(const intptr_t* sizes, size_t n) { + size_t total = 0; + for (size_t i = 0; i < n; ++i) { + if (sizes[i] != -1) { + total += sizes[i]; + } + } + return total; +} + +// Fills in arg_sizes with the byte size of each positional arg. +Status ComputeArgSizes(const CompileResult& compile_result, + std::vector* arg_sizes) { + const xla::ProgramShape& ps = compile_result.program_shape; + for (int i = 0; i < ps.parameters_size(); ++i) { + if (i == ps.parameters_size() - 1 && compile_result.has_context_arg) { + // If the compiled function needs a XlaLocalRuntimeContext* arg, it's + // always last, and must be represented as an opaque type. + const xla::PrimitiveType type = ps.parameters(i).element_type(); + if (type != xla::OPAQUE) { + return errors::InvalidArgument( + "expected final context arg to be opaque, but got type: ", + xla::PrimitiveType_Name(type), ", from program shape: ", + xla::ShapeUtil::HumanString(ps)); + } + arg_sizes->push_back(-1); + } else { + arg_sizes->push_back(xla::ShapeUtil::ByteSizeOf( + ps.parameters(i), compile_result.pointer_size)); + } + } + return Status::OK(); +} + +// Add (from,to) rewrite pairs based on the given shape. These rewrite pairs +// are used to generate methods for args and results. +Status AddRewritesForShape(int i, const xla::Shape& shape, + std::vector>* rewrites) { + string type; + TF_RETURN_IF_ERROR(XLATypeToCpp(shape.element_type(), &type)); + std::vector dim_vars; + string dim_sizes, indices; + if (xla::ShapeUtil::Rank(shape) == 0 || + (shape.dimensions_size() == 1 && shape.dimensions(0) == 1)) { + dim_sizes = "[1]"; + indices = "[0]"; + } else { + for (int dim = 0; dim < shape.dimensions_size(); ++dim) { + dim_vars.push_back(strings::StrCat("size_t dim", dim)); + dim_sizes += strings::StrCat("[", shape.dimensions(dim), "]"); + indices += strings::StrCat("[dim", dim, "]"); + } + } + rewrites->push_back({"{{I}}", strings::StrCat(i)}); + rewrites->push_back({"{{TYPE}}", type}); + rewrites->push_back({"{{DIM_VARS}}", str_util::Join(dim_vars, ", ")}); + rewrites->push_back({"{{DIM_SIZES}}", dim_sizes}); + rewrites->push_back({"{{INDICES}}", indices}); + return Status::OK(); +} + +// Returns code rewritten by replacing all rewrite pairs, with an extra rewrite +// for the name. Note that the rewriting strategy is roughly O(N*M), where N is +// the size of the code and M is the number of rewrites. It's fine for now +// since N and M are pretty small. +// +// TODO(toddw): If this becomes a problem, we should be able to change the +// algorithm to O(N) by using a state machine, e.g. 
regexps or a real +// text-templating mechanism. +string RewriteWithName(const string& name, string code, + const std::vector>& rewrites) { + str_util::ReplaceAllPairs(&code, rewrites); + str_util::ReplaceAll(&code, "{{NAME}}", name); + return code; +} + +// Generate methods for args (inputs). +Status GenArgMethods(const Config& config, const xla::ProgramShape& ps, + const CompileResult& compile_result, string* methods) { + *methods += R"( + void** args() { return args_; } + const void *const *args() const { return args_; } +)"; + size_t num_args = ps.parameters_size(); + if (compile_result.has_context_arg) { + // If the compiled function needs a XlaLocalRuntimeContext* arg, it's + // always last, and is set in the class constructor. + num_args--; + } + if (config.feed_size() != num_args) { + return errors::InvalidArgument("mismatch between feed_size(", + config.feed_size(), ") and num_args(", + num_args, ")"); + } + for (int i = 0; i < num_args; ++i) { + std::vector> rewrites; + TF_RETURN_IF_ERROR(AddRewritesForShape(i, ps.parameters(i), &rewrites)); + const string code = R"( + void set_arg{{NAME}}_data(void* data) { + args_[{{I}}] = data; + } + {{TYPE}}* arg{{NAME}}_data() { + return static_cast<{{TYPE}}*>(args_[{{I}}]); + } + {{TYPE}}& arg{{NAME}}({{DIM_VARS}}) { + return (*static_cast<{{TYPE}}(*){{DIM_SIZES}}>( + args_[{{I}}])){{INDICES}}; + } + const {{TYPE}}* arg{{NAME}}_data() const { + return static_cast(args_[{{I}}]); + } + const {{TYPE}}& arg{{NAME}}({{DIM_VARS}}) const { + return (*static_cast( + args_[{{I}}])){{INDICES}}; + } +)"; + *methods += RewriteWithName(strings::StrCat(i), code, rewrites); + if (!config.feed(i).name().empty()) { + *methods += RewriteWithName("_" + config.feed(i).name(), code, rewrites); + } + } + return Status::OK(); +} + +// Generate methods for results (outputs). +Status GenResultMethods(const Config& config, const xla::ProgramShape& ps, + string* methods) { + if (ps.result().element_type() != xla::TUPLE) { + // Non-tuple (i.e. single-result) case. + if (config.fetch_size() != 1) { + return errors::InvalidArgument( + "non-tuple result implies 1 fetch, but got ", config.fetch_size(), + " fetches"); + } + *methods += R"( + void** results() { return temps_ + kResultIndex; } + const void *const *results() const { return temps_ + kResultIndex; } +)"; + std::vector> rewrites; + TF_RETURN_IF_ERROR(AddRewritesForShape(0, ps.result(), &rewrites)); + const string code = R"( + {{TYPE}}* result{{NAME}}_data() { + return static_cast<{{TYPE}}*>(temps_[kResultIndex]); + } + {{TYPE}}& result{{NAME}}({{DIM_VARS}}) { + return (*static_cast<{{TYPE}}(*){{DIM_SIZES}}>( + temps_[kResultIndex])){{INDICES}}; + } + const {{TYPE}}* result{{NAME}}_data() const { + return static_cast(temps_[kResultIndex]); + } + const {{TYPE}}& result{{NAME}}({{DIM_VARS}}) const { + return (*static_cast( + temps_[kResultIndex])){{INDICES}}; + } +)"; + *methods += RewriteWithName("0", code, rewrites); + if (!config.fetch(0).name().empty()) { + *methods += RewriteWithName("_" + config.fetch(0).name(), code, rewrites); + } + return Status::OK(); + } + // Tuple (i.e. multi-result) case. 
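  // Illustrative example of what the tuple case below generates (shape and
  // fetch name are hypothetical): if result element 1 has XLA shape f32[2,3]
  // and the corresponding fetch is named "y", the rewrites yield
  // result1_data(), result1(dim0, dim1) and the name-aliased result_y_data(),
  // result_y(dim0, dim1), with {{TYPE}} = "float", {{DIM_SIZES}} = "[2][3]"
  // and {{INDICES}} = "[dim0][dim1]".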
+ if (config.fetch_size() != ps.result().tuple_shapes_size()) { + return errors::InvalidArgument("mismatch between fetch_size(", + config.feed_size(), ") and tuple_size(", + ps.result().tuple_shapes_size(), ")"); + } + *methods += R"( + void** results() { + return static_cast(temps_[kResultIndex]); + } + const void *const *results() const { + return static_cast(temps_[kResultIndex]); + } +)"; + for (int i = 0; i < ps.result().tuple_shapes_size(); ++i) { + std::vector> rewrites; + TF_RETURN_IF_ERROR( + AddRewritesForShape(i, ps.result().tuple_shapes(i), &rewrites)); + string code = R"( + {{TYPE}}* result{{NAME}}_data() { + return static_cast<{{TYPE}}*>( + static_cast(temps_[kResultIndex])[{{I}}]); + } + {{TYPE}}& result{{NAME}}({{DIM_VARS}}) { + return (*static_cast<{{TYPE}}(*){{DIM_SIZES}}>( + static_cast(temps_[kResultIndex])[{{I}}])){{INDICES}}; + } + const {{TYPE}}* result{{NAME}}_data() const { + return static_cast<{{TYPE}}*>( + static_cast(temps_[kResultIndex])[{{I}}]); + } + const {{TYPE}}& result{{NAME}}({{DIM_VARS}}) const { + return (*static_cast( + static_cast(temps_[kResultIndex])[{{I}}])){{INDICES}}; + } +)"; + *methods += RewriteWithName(strings::StrCat(i), code, rewrites); + if (!config.fetch(i).name().empty()) { + *methods += RewriteWithName("_" + config.fetch(i).name(), code, rewrites); + } + } + return Status::OK(); +} + +} // namespace + +Status GenerateHeader(const HeaderOpts& opts, const Config& config, + const CompileResult& compile_result, string* header) { + TF_RETURN_IF_ERROR(ValidateConfig(config)); + const int64 result_index = compile_result.aot->result_buffer_index(); + const xla::BufferSizes& temp_sizes = compile_result.aot->buffer_sizes(); + if (result_index < 0 || result_index > temp_sizes.size()) { + return errors::InvalidArgument("result index: ", result_index, + " is outside the range of temp sizes: [0,", + temp_sizes.size(), ")"); + } + + // Compute sizes and generate methods. + std::vector arg_sizes; + TF_RETURN_IF_ERROR(ComputeArgSizes(compile_result, &arg_sizes)); + const xla::ProgramShape& ps = compile_result.program_shape; + string methods_arg, methods_result; + TF_RETURN_IF_ERROR(GenArgMethods(config, ps, compile_result, &methods_arg)); + TF_RETURN_IF_ERROR(GenResultMethods(config, ps, &methods_result)); + const std::vector iarg(arg_sizes.begin(), arg_sizes.end()); + const std::vector itemp(temp_sizes.begin(), temp_sizes.end()); + const size_t arg_bytes_aligned = + runtime::aligned_buffer_bytes(iarg.data(), iarg.size()); + const size_t arg_bytes_total = total_buffer_bytes(iarg.data(), iarg.size()); + const size_t temp_bytes_aligned = + runtime::aligned_buffer_bytes(itemp.data(), itemp.size()); + const size_t temp_bytes_total = + total_buffer_bytes(itemp.data(), itemp.size()); + + // Create rewrite strings for the optional context arg. + string context_include; + string context_set_arg, context_set_thread_pool, context_member_var; + string run_result = "true"; + string error_msg = "tensorflow::string()"; + if (compile_result.has_context_arg) { + // NOTE: Extra spaces and newlines are used to ensure nice formatting. + context_include = + "#include " + "\"tensorflow/compiler/tf2xla/" + "xla_local_runtime_context.h\"\n"; + context_set_arg = " args_[kNumArgs-1] = &context_;\n"; + context_set_thread_pool = " context_.thread_pool = pool;\n"; + context_member_var = " tensorflow::XlaLocalRuntimeContext context_;\n"; + run_result = "!context_.error"; + error_msg = "context_.error_msg"; + } + + // Create rewrite strings for namespace start and end. 
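  // Illustrative example (hypothetical namespaces): for opts.namespaces =
  // {"foo", "bar"}, the loops below produce
  //   ns_start = "namespace foo {\nnamespace bar {\n\n"
  //   ns_end   = "\n} // end namespace bar\n} // end namespace foo\n"
  // i.e. namespaces are closed in reverse order of opening.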
+ string ns_start; + for (const string& n : opts.namespaces) { + ns_start += strings::StrCat("namespace ", n, " {\n"); + } + ns_start += "\n"; + string ns_end("\n"); + for (int i = opts.namespaces.size() - 1; i >= 0; --i) { + const string& n = opts.namespaces[i]; + ns_end += strings::StrCat("} // end namespace ", n, "\n"); + } + + // Use a poor-man's text templating mechanism; first populate the full header + // with placeholder tokens, and then rewrite the tokens with real values. + *header = + R"(// Generated by tfcompile, the TensorFlow graph compiler. DO NOT EDIT! +// +// This header was generated via ahead-of-time compilation of a TensorFlow +// graph. An object file corresponding to this header was also generated. +// This header gives access to the functionality in that object file. +// +// clang-format off + +#ifndef TFCOMPILE_GENERATED_{{ENTRY}}_H_ // NOLINT(build/header_guard) +#define TFCOMPILE_GENERATED_{{ENTRY}}_H_ // NOLINT(build/header_guard) + +{{CONTEXT_INCLUDE}} +#include "tensorflow/compiler/aot/runtime.h" +#include "tensorflow/compiler/xla/executable_run_options.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace Eigen { class ThreadPoolDevice; } + +// (Implementation detail) Entry point to the function in the object file. +extern "C" void {{ENTRY}}( + void* result, xla::ExecutableRunOptions* run_options, + void** args, void** temps); + +{{NS_START}} +// {{CLASS}} represents a computation previously specified in a +// TensorFlow graph, now compiled into executable code. Usage example: +// +// {{CLASS}} computation; +// // ...set args using computation.argN methods +// CHECK(computation.Run()); +// // ...inspect results using computation.resultN methods +// +// The Run method invokes the actual computation, with inputs read from arg +// buffers, and outputs written to result buffers. Each Run call may also use +// a set of temporary buffers for the computation. +// +// By default each instance of this class manages its own arg, result and temp +// buffers. The AllocMode constructor parameter may be used to modify the +// buffer allocation strategy. +// +// Under the default allocation strategy, this class is thread-compatible: +// o Calls to non-const methods require exclusive access to the object. +// o Concurrent calls to const methods are OK, if those calls are made while +// it is guaranteed that no thread may call a non-const method. +// +// The logical function signature is: +// {{PROGRAM_SHAPE}} +// +// Memory stats: +// arg bytes total: {{ARG_BYTES_TOTAL}} +// arg bytes aligned: {{ARG_BYTES_ALIGNED}} +// temp bytes total: {{TEMP_BYTES_TOTAL}} +// temp bytes aligned: {{TEMP_BYTES_ALIGNED}} +class {{CLASS}} { + public: + // Number of input arguments for the compiled computation. + static constexpr size_t kNumArgs = {{ARG_NUM}}; + + // Byte size of each argument buffer. There are kNumArgs entries. + static const intptr_t* ArgSizes() { + static constexpr intptr_t kArgSizes[kNumArgs] = {{{ARG_SIZES}}}; + return kArgSizes; + } + + // AllocMode controls the buffer allocation mode. + enum class AllocMode { + // Allocate all buffers - args, results and temps. + ARGS_RESULTS_AND_TEMPS, + + // Only allocate result and temp buffers. + // Use set_argN_data to set argument buffers before Run is called. 
+ RESULTS_AND_TEMPS_ONLY, + }; + + {{CLASS}}(AllocMode mode = AllocMode::ARGS_RESULTS_AND_TEMPS) { + if (mode == AllocMode::ARGS_RESULTS_AND_TEMPS) { + alloc_args_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers( + ArgSizes(), kNumArgs, args_, false /* annotate_initialized */); + } +{{CONTEXT_SET_ARG}} + alloc_temps_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers( + TempSizes(), kNumTemps, temps_, true /* annotate_initialized */); + } + + ~{{CLASS}}() { + tensorflow::tfcompile::runtime::FreeContiguous(alloc_args_); + tensorflow::tfcompile::runtime::FreeContiguous(alloc_temps_); + } + + // Sets the thread pool to use during the Run call. + {{CLASS}}& set_thread_pool(const Eigen::ThreadPoolDevice* pool) { + run_options_.set_intra_op_thread_pool(pool); +{{CONTEXT_SET_THREAD_POOL}} + return *this; + } + + // Runs the computation, with inputs read from arg buffers, and outputs + // written to result buffers. Returns true on success and false on failure. + bool Run() { + {{ENTRY}}(temps_[kResultIndex], &run_options_, args_, temps_); + return {{RUN_RESULT}}; + } + + // Returns the error message from the previous failed Run call. + tensorflow::string error_msg() const { return {{ERROR_MSG}}; } + + // Arg methods for managing input buffers. Buffers are in row-major order. + // There is a set of methods for each positional argument, with the following + // general form: + // + // void set_argN_data(void* data) + // Sets the buffer of type T for positional argument N. May be called in + // any AllocMode. Must be called before Run to have an affect. Must be + // called in AllocMode::RESULTS_AND_TEMPS_ONLY for each positional argument, + // to set the argument buffers. + // + // T* argN_data() + // Returns the buffer of type T for positional argument N. + // + // T& argN(...dim indices...) + // Returns a reference to the value of type T for positional argument N, + // with dim indices specifying which value. No bounds checking is performed + // on dim indices. + // + // void** args() + // Returns an array of argument buffers, where args()[N] is the buffer for + // positional argument N. +{{METHODS_ARG}} + + // Result methods for managing output buffers. Buffers are in row-major order. + // Must only be called after a successful Run call. There is a set of methods + // for each positional result, with the following general form: + // + // T* resultN_data() + // Returns the buffer of type T for positional result N. + // + // T& resultN(...dim indices...) + // Returns a reference to the value of type T for positional result N, + // with dim indices specifying which value. No bounds checking is performed + // on dim indices. + // + // void** results() + // Returns an array of result buffers, where results()[N] is the buffer for + // positional result N. + // + // Unlike the arg methods, there is no set_resultN_data method. The result + // buffers are managed internally, and may change after each call to Run. +{{METHODS_RESULT}} + + private: + // Number of result and temporary buffers for the compiled computation. + static constexpr size_t kNumTemps = {{TEMP_NUM}}; + // The 0-based index of the result in the temporary buffers. + static constexpr size_t kResultIndex = {{RESULT_INDEX}}; + + // Byte size of each result / temporary buffer. There are kNumTemps entries. 
+ static const intptr_t* TempSizes() { + static constexpr intptr_t kTempSizes[kNumTemps] = {{{TEMP_SIZES}}}; + return kTempSizes; + } + + void* args_[kNumArgs]; + void* temps_[kNumTemps]; + void* alloc_args_ = nullptr; + void* alloc_temps_ = nullptr; + xla::ExecutableRunOptions run_options_; +{{CONTEXT_MEMBER_VAR}} + + TF_DISALLOW_COPY_AND_ASSIGN({{CLASS}}); +}; +{{NS_END}} + +#endif // TFCOMPILE_GENERATED_{{ENTRY}}_H_ + +// clang-format on +)"; + // The replacement strategy is naive, but good enough for our purposes. + const std::vector> rewrites = { + {"{{ARG_BYTES_ALIGNED}}", strings::StrCat(arg_bytes_aligned)}, + {"{{ARG_BYTES_TOTAL}}", strings::StrCat(arg_bytes_total)}, + {"{{ARG_NUM}}", strings::StrCat(arg_sizes.size())}, + {"{{ARG_SIZES}}", str_util::Join(arg_sizes, ", ")}, + {"{{CLASS}}", opts.class_name}, + {"{{CONTEXT_INCLUDE}}\n", context_include}, + {"{{CONTEXT_MEMBER_VAR}}\n", context_member_var}, + {"{{CONTEXT_SET_ARG}}\n", context_set_arg}, + {"{{CONTEXT_SET_THREAD_POOL}}\n", context_set_thread_pool}, + {"{{ENTRY}}", compile_result.entry_point}, + {"{{ERROR_MSG}}", error_msg}, + {"{{METHODS_ARG}}\n", methods_arg}, + {"{{METHODS_RESULT}}\n", methods_result}, + {"{{NS_END}}\n", ns_end}, + {"{{NS_START}}\n", ns_start}, + {"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(ps)}, + {"{{RESULT_INDEX}}", strings::StrCat(result_index)}, + {"{{RUN_RESULT}}", run_result}, + {"{{TEMP_BYTES_ALIGNED}}", strings::StrCat(temp_bytes_aligned)}, + {"{{TEMP_BYTES_TOTAL}}", strings::StrCat(temp_bytes_total)}, + {"{{TEMP_NUM}}", strings::StrCat(temp_sizes.size())}, + {"{{TEMP_SIZES}}", str_util::Join(temp_sizes, ", ")}, + }; + str_util::ReplaceAllPairs(header, rewrites); + return Status::OK(); +} + +Status ParseCppClass(const string& cpp_class, string* class_name, + std::vector* namespaces) { + class_name->clear(); + namespaces->clear(); + size_t begin = 0; + size_t end = 0; + while ((end = cpp_class.find("::", begin)) != string::npos) { + const string ns = cpp_class.substr(begin, end - begin); + TF_RETURN_IF_ERROR(ValidateCppIdent( + ns, "in namespace component of cpp_class: " + cpp_class)); + namespaces->push_back(ns); + begin = end + 2; // +2 to skip the two colons + } + const string name = cpp_class.substr(begin); + TF_RETURN_IF_ERROR( + ValidateCppIdent(name, "in class name of cpp_class: " + cpp_class)); + *class_name = name; + return Status::OK(); +} + +} // namespace tfcompile +} // namespace tensorflow diff --git a/tensorflow/compiler/aot/codegen.h b/tensorflow/compiler/aot/codegen.h new file mode 100644 index 0000000000..7217c57739 --- /dev/null +++ b/tensorflow/compiler/aot/codegen.h @@ -0,0 +1,53 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
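One subtlety in the rewrites table above: several keys are spelled with a trailing newline ("{{CONTEXT_SET_ARG}}\n" rather than "{{CONTEXT_SET_ARG}}"), which lets an empty replacement collapse the placeholder line entirely, e.g. when the computation needs no runtime context. A reduced sketch of the same naive substitution scheme, written here only for illustration (the patch itself relies on str_util::ReplaceAllPairs):

  #include <string>
  #include <utility>
  #include <vector>

  // Substitute every occurrence of each {token, value} pair in `text`.
  void ReplaceAllPairsSketch(
      std::string* text,
      const std::vector<std::pair<std::string, std::string>>& rewrites) {
    for (const auto& rewrite : rewrites) {
      std::string::size_type pos = 0;
      while ((pos = text->find(rewrite.first, pos)) != std::string::npos) {
        text->replace(pos, rewrite.first.size(), rewrite.second);
        pos += rewrite.second.size();
      }
    }
  }

  // Mapping "{{SET_ARG}}\n" -> "" deletes the placeholder line outright;
  // mapping only "{{SET_ARG}}" -> "" would leave an empty line behind.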
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_AOT_CODEGEN_H_ +#define TENSORFLOW_COMPILER_AOT_CODEGEN_H_ + +#include <string> +#include <vector> + +#include "tensorflow/compiler/aot/compile.h" + +namespace tensorflow { +namespace tfcompile { + +// HeaderOpts specifies options for header-file generation. +struct HeaderOpts { + // The name of the generated C++ class, wrapping the generated function. + string class_name; + + // Namespaces specifies a list of C++ namespaces to add to the generated + // header. If empty, all symbols will be in the global namespace. + std::vector<string> namespaces; +}; + +// GenerateHeader uses the meta-information from compile_result to generate a +// C++ header giving access to the function in the generated object file. The +// header includes API usage documentation. +Status GenerateHeader(const HeaderOpts& opts, const Config& config, + const CompileResult& compile_result, string* header); + +// ParseCppClass parses `cpp_class` into its `class_name` and `namespaces` +// components. The syntax is [[<optional_namespace>::],...]<class_name>. This +// mirrors the C++ syntax for referring to a class, where multiple namespaces +// may precede the class name, separated by double-colons. +Status ParseCppClass(const string& cpp_class, string* class_name, + std::vector<string>* namespaces); + +} // namespace tfcompile +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_AOT_CODEGEN_H_ diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc new file mode 100644 index 0000000000..e3f76f3666 --- /dev/null +++ b/tensorflow/compiler/aot/codegen_test.cc @@ -0,0 +1,137 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
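Taken together, ParseCppClass and GenerateHeader are the whole codegen surface: the driver splits the --cpp_class flag into namespaces plus class name, then renders the header that matches the compiled object file. A rough sketch of that call pattern from inside a Status-returning helper, assuming the caller already holds a Config and a filled-in CompileResult (the real wiring lives in tfcompile_main.cc, which is part of this patch but not this hunk; the output path below is a placeholder):

  HeaderOpts opts;
  TF_RETURN_IF_ERROR(
      ParseCppClass("foo::bar::MyClass", &opts.class_name, &opts.namespaces));
  // opts.class_name == "MyClass", opts.namespaces == {"foo", "bar"}

  string header;
  TF_RETURN_IF_ERROR(GenerateHeader(opts, config, compile_result, &header));
  TF_RETURN_IF_ERROR(
      WriteStringToFile(Env::Default(), "/tmp/my_class.h", header));

WriteStringToFile and TF_RETURN_IF_ERROR are the usual tensorflow/core helpers, also used by the codegen test below.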
+==============================================================================*/ + +#include "tensorflow/compiler/aot/codegen.h" + +#include +#include + +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace tfcompile { +namespace { + +class ParseCppClassTest : public ::testing::Test { + protected: + void ExpectOK(const string& cpp_class, const string& want_class_name, + const std::vector& want_namespaces) { + string class_name; + std::vector namespaces; + TF_EXPECT_OK(ParseCppClass(cpp_class, &class_name, &namespaces)); + EXPECT_EQ(class_name, want_class_name); + EXPECT_EQ(namespaces, want_namespaces); + } + + void ExpectFail(const string& cpp_class) { + string class_name; + std::vector namespaces; + EXPECT_NE(ParseCppClass(cpp_class, &class_name, &namespaces), Status::OK()); + } +}; + +TEST_F(ParseCppClassTest, ParseOK) { + ExpectOK("MyClass", "MyClass", {}); + ExpectOK("_MyClass", "_MyClass", {}); + ExpectOK("a::MyClass", "MyClass", {"a"}); + ExpectOK("a::foo::MyClass", "MyClass", {"a", "foo"}); + ExpectOK("a::foo::b::MyClass", "MyClass", {"a", "foo", "b"}); + ExpectOK("a::foo::b::bar::MyClass", "MyClass", {"a", "foo", "b", "bar"}); + ExpectOK("foo::MyClass", "MyClass", {"foo"}); + ExpectOK("_foo::MyClass", "MyClass", {"_foo"}); + ExpectOK("_foo::_MyClass", "_MyClass", {"_foo"}); + // Make sure we didn't skip a valid letter or digit + string ident; + for (char c = 'a'; c <= 'z'; c++) { + ident.append(1, c); + } + for (char c = 'A'; c <= 'Z'; c++) { + ident.append(1, c); + } + for (char c = '0'; c <= '9'; c++) { + ident.append(1, c); + } + ident += "_"; + ExpectOK(ident, ident, {}); + ExpectOK(ident + "::" + ident, ident, {ident}); + ExpectOK(ident + "::" + ident + "::" + ident, ident, {ident, ident}); +} + +TEST_F(ParseCppClassTest, ParseFail) { + ExpectFail(""); + ExpectFail("::"); + ExpectFail("::MyClass"); // valid C++, but disallowed for simpler code. + ExpectFail("0"); + ExpectFail("a.b"); + ExpectFail("a:b"); + ExpectFail("good::.bad"); + ExpectFail("good:::bad"); + ExpectFail("good:: bad"); + ExpectFail("good::0bad"); +} + +TEST(GenerateHeader, Golden) { + HeaderOpts opts; + opts.class_name = "MyClass"; + opts.namespaces = {"foo", "bar"}; + Config config; + Feed* feed = config.add_feed(); + feed->mutable_id()->set_node_name("feed0"); + feed->set_name("myfeed"); + feed = config.add_feed(); + feed->mutable_id()->set_node_name("feed1"); + Fetch* fetch = config.add_fetch(); + fetch->mutable_id()->set_node_name("fetch0"); + fetch->set_name("myfetch"); + CompileResult compile_result; + compile_result.aot.reset( + new xla::cpu::CpuAotCompilationResult({}, {1, -1, 2, -1, 3, 120}, 5)); + compile_result.program_shape = xla::ShapeUtil::MakeProgramShape( + { + xla::ShapeUtil::MakeShape(xla::F32, {1, 2}), + xla::ShapeUtil::MakeShape(xla::S64, {3, 4}), + xla::ShapeUtil::MakeOpaqueShape(), + }, + xla::ShapeUtil::MakeShape(xla::U32, {5, 6})); + compile_result.has_context_arg = true; + compile_result.entry_point = "entry_point"; + compile_result.pointer_size = 8; + string header; + TF_EXPECT_OK(GenerateHeader(opts, config, compile_result, &header)); + + // Compare against the golden file. 
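The numbers that appear in the golden header follow directly from this fixture, which is worth spelling out once. The program shape gives argument buffers of 1*2*4 = 8 bytes (f32[1,2]) and 3*4*8 = 96 bytes (s64[3,4]), plus a -1 entry for the opaque argument (the runtime-context slot), so "arg bytes total" is 8 + 96 = 104 and, with each buffer rounded up to the 32-byte kAlign boundary, "arg bytes aligned" is 32 + 96 = 128. Likewise the temp sizes {1, -1, 2, -1, 3, 120} sum to 126 bytes total and 32 + 32 + 32 + 128 = 224 bytes aligned, and kResultIndex in the golden file is the 5 passed to CpuAotCompilationResult above.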
+ const string golden_name = io::JoinPath(testing::TensorFlowSrcRoot(), + "compiler/aot/codegen_test_h.golden"); + // To update the golden file, flip update_golden to true and run the + // following: + // bazel test --test_strategy=local \ + // third_party/tensorflow/compiler/aot:codegen_test + const bool update_golden = false; + if (update_golden) { + TF_EXPECT_OK(WriteStringToFile(Env::Default(), golden_name, header)); + } + string golden_data; + TF_EXPECT_OK(ReadFileToString(Env::Default(), golden_name, &golden_data)); + EXPECT_EQ(header, golden_data); +} + +} // namespace +} // namespace tfcompile +} // namespace tensorflow diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden new file mode 100644 index 0000000000..46d7c03006 --- /dev/null +++ b/tensorflow/compiler/aot/codegen_test_h.golden @@ -0,0 +1,268 @@ +// Generated by tfcompile, the TensorFlow graph compiler. DO NOT EDIT! +// +// This header was generated via ahead-of-time compilation of a TensorFlow +// graph. An object file corresponding to this header was also generated. +// This header gives access to the functionality in that object file. +// +// clang-format off + +#ifndef TFCOMPILE_GENERATED_entry_point_H_ // NOLINT(build/header_guard) +#define TFCOMPILE_GENERATED_entry_point_H_ // NOLINT(build/header_guard) + +#include "tensorflow/compiler/tf2xla/xla_local_runtime_context.h" +#include "tensorflow/compiler/aot/runtime.h" +#include "tensorflow/compiler/xla/executable_run_options.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace Eigen { class ThreadPoolDevice; } + +// (Implementation detail) Entry point to the function in the object file. +extern "C" void entry_point( + void* result, xla::ExecutableRunOptions* run_options, + void** args, void** temps); + +namespace foo { +namespace bar { + +// MyClass represents a computation previously specified in a +// TensorFlow graph, now compiled into executable code. Usage example: +// +// MyClass computation; +// // ...set args using computation.argN methods +// CHECK(computation.Run()); +// // ...inspect results using computation.resultN methods +// +// The Run method invokes the actual computation, with inputs read from arg +// buffers, and outputs written to result buffers. Each Run call may also use +// a set of temporary buffers for the computation. +// +// By default each instance of this class manages its own arg, result and temp +// buffers. The AllocMode constructor parameter may be used to modify the +// buffer allocation strategy. +// +// Under the default allocation strategy, this class is thread-compatible: +// o Calls to non-const methods require exclusive access to the object. +// o Concurrent calls to const methods are OK, if those calls are made while +// it is guaranteed that no thread may call a non-const method. +// +// The logical function signature is: +// ((unknown): f32[1,2], (unknown): s64[3,4], (unknown): opaque[]) -> u32[5,6] +// +// Memory stats: +// arg bytes total: 104 +// arg bytes aligned: 128 +// temp bytes total: 126 +// temp bytes aligned: 224 +class MyClass { + public: + // Number of input arguments for the compiled computation. + static constexpr size_t kNumArgs = 3; + + // Byte size of each argument buffer. There are kNumArgs entries. + static const intptr_t* ArgSizes() { + static constexpr intptr_t kArgSizes[kNumArgs] = {8, 96, -1}; + return kArgSizes; + } + + // AllocMode controls the buffer allocation mode. 
+ enum class AllocMode { + // Allocate all buffers - args, results and temps. + ARGS_RESULTS_AND_TEMPS, + + // Only allocate result and temp buffers. + // Use set_argN_data to set argument buffers before Run is called. + RESULTS_AND_TEMPS_ONLY, + }; + + MyClass(AllocMode mode = AllocMode::ARGS_RESULTS_AND_TEMPS) { + if (mode == AllocMode::ARGS_RESULTS_AND_TEMPS) { + alloc_args_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers( + ArgSizes(), kNumArgs, args_, false /* annotate_initialized */); + } + args_[kNumArgs-1] = &context_; + alloc_temps_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers( + TempSizes(), kNumTemps, temps_, true /* annotate_initialized */); + } + + ~MyClass() { + tensorflow::tfcompile::runtime::FreeContiguous(alloc_args_); + tensorflow::tfcompile::runtime::FreeContiguous(alloc_temps_); + } + + // Sets the thread pool to use during the Run call. + MyClass& set_thread_pool(const Eigen::ThreadPoolDevice* pool) { + run_options_.set_intra_op_thread_pool(pool); + context_.thread_pool = pool; + return *this; + } + + // Runs the computation, with inputs read from arg buffers, and outputs + // written to result buffers. Returns true on success and false on failure. + bool Run() { + entry_point(temps_[kResultIndex], &run_options_, args_, temps_); + return !context_.error; + } + + // Returns the error message from the previous failed Run call. + tensorflow::string error_msg() const { return context_.error_msg; } + + // Arg methods for managing input buffers. Buffers are in row-major order. + // There is a set of methods for each positional argument, with the following + // general form: + // + // void set_argN_data(void* data) + // Sets the buffer of type T for positional argument N. May be called in + // any AllocMode. Must be called before Run to have an affect. Must be + // called in AllocMode::RESULTS_AND_TEMPS_ONLY for each positional argument, + // to set the argument buffers. + // + // T* argN_data() + // Returns the buffer of type T for positional argument N. + // + // T& argN(...dim indices...) + // Returns a reference to the value of type T for positional argument N, + // with dim indices specifying which value. No bounds checking is performed + // on dim indices. + // + // void** args() + // Returns an array of argument buffers, where args()[N] is the buffer for + // positional argument N. 
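Concretely, for the MyClass generated from this golden configuration (a float [1,2] argument, an int64 [3,4] argument, and a uint32 [5,6] result), typical calling code would look roughly like the following sketch; it is written against the accessors shown below rather than taken from the patch itself:

  foo::bar::MyClass computation;
  computation.arg0(0, 1) = 3.5f;                   // same storage as arg_myfeed(0, 1)
  computation.arg1(2, 3) = tensorflow::int64{42};
  if (computation.Run()) {
    tensorflow::uint32 v = computation.result0(4, 5);
    (void)v;
  }

Because the named accessors (arg_myfeed, result_myfetch) alias the positional ones, either spelling reads and writes the same buffer.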
+ + void** args() { return args_; } + const void *const *args() const { return args_; } + + void set_arg0_data(void* data) { + args_[0] = data; + } + float* arg0_data() { + return static_cast(args_[0]); + } + float& arg0(size_t dim0, size_t dim1) { + return (*static_cast( + args_[0]))[dim0][dim1]; + } + const float* arg0_data() const { + return static_cast(args_[0]); + } + const float& arg0(size_t dim0, size_t dim1) const { + return (*static_cast( + args_[0]))[dim0][dim1]; + } + + void set_arg_myfeed_data(void* data) { + args_[0] = data; + } + float* arg_myfeed_data() { + return static_cast(args_[0]); + } + float& arg_myfeed(size_t dim0, size_t dim1) { + return (*static_cast( + args_[0]))[dim0][dim1]; + } + const float* arg_myfeed_data() const { + return static_cast(args_[0]); + } + const float& arg_myfeed(size_t dim0, size_t dim1) const { + return (*static_cast( + args_[0]))[dim0][dim1]; + } + + void set_arg1_data(void* data) { + args_[1] = data; + } + tensorflow::int64* arg1_data() { + return static_cast(args_[1]); + } + tensorflow::int64& arg1(size_t dim0, size_t dim1) { + return (*static_cast( + args_[1]))[dim0][dim1]; + } + const tensorflow::int64* arg1_data() const { + return static_cast(args_[1]); + } + const tensorflow::int64& arg1(size_t dim0, size_t dim1) const { + return (*static_cast( + args_[1]))[dim0][dim1]; + } + + // Result methods for managing output buffers. Buffers are in row-major order. + // Must only be called after a successful Run call. There is a set of methods + // for each positional result, with the following general form: + // + // T* resultN_data() + // Returns the buffer of type T for positional result N. + // + // T& resultN(...dim indices...) + // Returns a reference to the value of type T for positional result N, + // with dim indices specifying which value. No bounds checking is performed + // on dim indices. + // + // void** results() + // Returns an array of result buffers, where results()[N] is the buffer for + // positional result N. + // + // Unlike the arg methods, there is no set_resultN_data method. The result + // buffers are managed internally, and may change after each call to Run. + + void** results() { return temps_ + kResultIndex; } + const void *const *results() const { return temps_ + kResultIndex; } + + tensorflow::uint32* result0_data() { + return static_cast(temps_[kResultIndex]); + } + tensorflow::uint32& result0(size_t dim0, size_t dim1) { + return (*static_cast( + temps_[kResultIndex]))[dim0][dim1]; + } + const tensorflow::uint32* result0_data() const { + return static_cast(temps_[kResultIndex]); + } + const tensorflow::uint32& result0(size_t dim0, size_t dim1) const { + return (*static_cast( + temps_[kResultIndex]))[dim0][dim1]; + } + + tensorflow::uint32* result_myfetch_data() { + return static_cast(temps_[kResultIndex]); + } + tensorflow::uint32& result_myfetch(size_t dim0, size_t dim1) { + return (*static_cast( + temps_[kResultIndex]))[dim0][dim1]; + } + const tensorflow::uint32* result_myfetch_data() const { + return static_cast(temps_[kResultIndex]); + } + const tensorflow::uint32& result_myfetch(size_t dim0, size_t dim1) const { + return (*static_cast( + temps_[kResultIndex]))[dim0][dim1]; + } + + private: + // Number of result and temporary buffers for the compiled computation. + static constexpr size_t kNumTemps = 6; + // The 0-based index of the result in the temporary buffers. + static constexpr size_t kResultIndex = 5; + + // Byte size of each result / temporary buffer. There are kNumTemps entries. 
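The dimensioned accessors above are just casts to native C arrays, so the row-major flat indexing is easy to state explicitly. A small sketch of the equivalence for the [5,6] result, reusing the illustrative `computation` object from the sketch above (not from the patch):

  // result0(i, j) and the flat buffer agree under row-major layout:
  //   result0(i, j) == result0_data()[i * 6 + j]   for 0 <= i < 5, 0 <= j < 6
  tensorflow::uint32* flat = computation.result0_data();
  bool same = (&computation.result0(2, 3) == &flat[2 * 6 + 3]);
  (void)same;  // true by construction of the cast

No bounds checking happens anywhere on this path, so out-of-range indices are silent buffer overruns.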
+ static const intptr_t* TempSizes() { + static constexpr intptr_t kTempSizes[kNumTemps] = {1, -1, 2, -1, 3, 120}; + return kTempSizes; + } + + void* args_[kNumArgs]; + void* temps_[kNumTemps]; + void* alloc_args_ = nullptr; + void* alloc_temps_ = nullptr; + xla::ExecutableRunOptions run_options_; + tensorflow::XlaLocalRuntimeContext context_; + + TF_DISALLOW_COPY_AND_ASSIGN(MyClass); +}; + +} // end namespace bar +} // end namespace foo + +#endif // TFCOMPILE_GENERATED_entry_point_H_ + +// clang-format on diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc new file mode 100644 index 0000000000..50e596786a --- /dev/null +++ b/tensorflow/compiler/aot/compile.cc @@ -0,0 +1,416 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/aot/compile.h" + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/aot/flags.h" +#include "tensorflow/compiler/aot/tfcompile_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/service/compiler.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/graph_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { +namespace tfcompile { + +const char* const kArgOp = "_Arg"; +const char* const kRetvalOp = "_Retval"; +const char* const kFeedIdAttr = "_feed_id"; +const char* const kFetchIdAttr = "_fetch_id"; +const char* const kShapeAttr = "_shape"; +const char* const kDebugNameAttr = "_debug_name"; + +namespace { + +Status DumpGraph(const MainFlags& flags, const string& name, + const Graph& graph) { + if (flags.debug_dir.empty()) { + return Status::OK(); + } + GraphDef graph_def; + graph.ToGraphDef(&graph_def); + 
string file = io::JoinPath(flags.debug_dir, name + ".pbtxt"); + return WriteTextProto(Env::Default(), file, graph_def); +} + +string TensorIdToString(const TensorId& id) { + return strings::StrCat(id.node_name(), ":", id.output_index()); +} + +typedef std::unordered_map NodeMap; + +// Each feed id identifies the positional output of some node, which may consist +// of multiple edges. For each feed node, replaces all matching edges so that +// they point from a new _Arg node instead. +Status AddArgNodes(Graph* graph, const NodeMap& node_map, + const protobuf::RepeatedPtrField& feeds) { + for (int arg_index = 0; arg_index < feeds.size(); ++arg_index) { + const Feed& feed = feeds[arg_index]; + const TensorId& id = feed.id(); + auto it = node_map.find(id.node_name()); + if (it == node_map.end()) { + return errors::NotFound("Can't find feed id: ", TensorIdToString(id)); + } + const Node* feed_node = it->second; + if (id.output_index() >= feed_node->num_outputs()) { + return errors::InvalidArgument("Invalid feed id: ", TensorIdToString(id), + ", output index should be < ", + feed_node->num_outputs()); + } + // TODO(toddw): Invoke shape inference on the graph and add a "_shape" attr + // if we can determine it. That way the graph will be initialized with + // whatever shapes we can infer, while the user can still explicitly specify + // or override them. + Node* arg_node = nullptr; + TF_RETURN_IF_ERROR( + NodeBuilder(strings::StrCat("_arg_", arg_index), kArgOp) + .Attr("T", BaseType(feed_node->output_type(id.output_index()))) + .Attr("index", arg_index) + .Attr(kFeedIdAttr, TensorIdToString(id)) + .Attr(kShapeAttr, TensorShape(feed.shape())) + .Attr(kDebugNameAttr, feed.name()) + .Finalize(graph, &arg_node)); + // Collects out-edges from the feed node that have a matching edge index; + // these will be replaced with edges from the arg node instead. Also + // replaces all control edges from Placeholder feed nodes; similar code + // exists in subgraph::RewriteGraphForExecution. + // TODO(toddw): Why only replace control edges from Placeholder? + // + // We must collect the edges first and process them in a second pass, since + // removing the edge from the graph invalidates feed_node->out_edges. + std::vector feed_edges; + for (const Edge* edge : feed_node->out_edges()) { + if (edge->src_output() == id.output_index() || + (edge->src_output() == Graph::kControlSlot && + feed_node->type_string() == "Placeholder")) { + feed_edges.push_back(edge); + } + } + for (const Edge* edge : feed_edges) { + if (edge->src_output() == id.output_index()) { + graph->AddEdge(arg_node, 0, edge->dst(), edge->dst_input()); + } else { + CHECK_EQ(edge->src_output(), Graph::kControlSlot); + graph->AddControlEdge(arg_node, edge->dst()); + } + graph->RemoveEdge(edge); + } + } + return Status::OK(); +} + +// Each fetch id identifies the positional output of some node. For each fetch +// node, adds a new _Retval node instead, and adds the node to `retval_nodes`. 
+Status AddRetvalNodes(Graph* graph, const NodeMap& node_map, + const protobuf::RepeatedPtrField& fetches, + std::unordered_set* retval_nodes) { + for (int ret_index = 0; ret_index < fetches.size(); ++ret_index) { + const TensorId& id = fetches[ret_index].id(); + auto it = node_map.find(id.node_name()); + if (it == node_map.end()) { + return errors::NotFound("Can't find fetch id: ", TensorIdToString(id)); + } + Node* fetch_node = it->second; + if (id.output_index() >= fetch_node->num_outputs()) { + return errors::InvalidArgument("Invalid fetch id: ", TensorIdToString(id), + ", output index should be < ", + fetch_node->num_outputs()); + } + // Connects fetch_node -> retval_node. + Node* retval_node = nullptr; + TF_RETURN_IF_ERROR( + NodeBuilder(strings::StrCat("_retval_", ret_index), kRetvalOp) + .Input(fetch_node, id.output_index()) + .Attr("T", BaseType(fetch_node->output_type(id.output_index()))) + .Attr("index", ret_index) + .Attr(kFetchIdAttr, TensorIdToString(id)) + .Finalize(graph, &retval_node)); + retval_nodes->insert(retval_node); + } + return Status::OK(); +} + +// RewriteAndPruneGraph identifies input and output edges (named by the feed and +// fetch ids respectively), and rewrites the edges so that inputs flow from _Arg +// nodes, and outputs flow to _Retval nodes. This allows the symbolic graph +// execution to know the input and output args for the generated function. +Status RewriteAndPruneGraph(Graph* graph, const Config& config, + const MainFlags& flags) { + NodeMap node_map; + for (Node* n : graph->nodes()) { + node_map[n->name()] = n; + } + TF_RETURN_IF_ERROR(AddArgNodes(graph, node_map, config.feed())); + std::unordered_set retval_nodes; + TF_RETURN_IF_ERROR( + AddRetvalNodes(graph, node_map, config.fetch(), &retval_nodes)); + TF_RETURN_IF_ERROR(DumpGraph(flags, "tfcompile_post_rewrite", *graph)); + PruneForReverseReachability(graph, retval_nodes); + FixupSourceAndSinkEdges(graph); + TF_RETURN_IF_ERROR(DumpGraph(flags, "tfcompile_post_prune", *graph)); + // Sanity-check, to make sure the feeds and fetches still exist post-pruning. + std::set missing_feeds, missing_fetches; + for (const Feed& feed : config.feed()) { + missing_feeds.insert(TensorIdToString(feed.id())); + } + for (const Fetch& fetch : config.fetch()) { + missing_fetches.insert(TensorIdToString(fetch.id())); + } + for (const Node* n : graph->nodes()) { + if (n->type_string() == kArgOp) { + string feed_id; + TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), kFeedIdAttr, &feed_id)); + if (missing_feeds.erase(feed_id) == 0) { + return errors::Aborted(kArgOp, " node found with unknown feed id: ", + feed_id); + } + } else if (n->type_string() == kRetvalOp) { + string fetch_id; + TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), kFetchIdAttr, &fetch_id)); + if (missing_fetches.erase(fetch_id) == 0) { + return errors::Aborted(kRetvalOp, " node found with unknown fetch id: ", + fetch_id); + } + } + } + if (!missing_feeds.empty() || !missing_fetches.empty()) { + return errors::Aborted("Post graph-pruning", ", missing feeds: ", + str_util::Join(missing_feeds, ", "), + ", missing fetches: ", + str_util::Join(missing_fetches, ", ")); + } + return Status::OK(); +} + +// CollectArgNodes collects _Arg nodes from the graph, and performs basic +// sanity-checking to ensure the index and type attributes of each node are +// initialized correctly. 
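As a concrete example of what RewriteAndPruneGraph does, consider the test_graph_tfadd graph included later in this patch (feeds x_const and y_const, fetch x_y_sum). Before rewriting, the relevant edges are:

  x_const -> x_y_sum      y_const -> x_y_sum

After AddArgNodes, AddRetvalNodes and pruning, the compiled function instead sees:

  _arg_0 -> x_y_sum -> _retval_0
  _arg_1 -> x_y_sum

with the original Const nodes dropped by PruneForReverseReachability, since nothing reachable backwards from the _Retval node depends on them anymore.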
+Status CollectArgNodes(const Graph& graph, std::vector* arg_nodes) { + std::map indexed_arg_nodes; + for (Node* n : graph.nodes()) { + if (n->type_string() == kArgOp) { + int index; + TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index)); + auto insert_result = indexed_arg_nodes.insert({index, n}); + if (!insert_result.second) { + const Node* dup = insert_result.first->second; + return errors::InvalidArgument( + "Multiple ", kArgOp, " nodes with index ", index, ", ", + n->DebugString(), " and ", dup->DebugString()); + } + } + } + arg_nodes->clear(); + for (const auto& index_node : indexed_arg_nodes) { + if (index_node.first != arg_nodes->size()) { + return errors::InvalidArgument("Expected ", kArgOp, " node with index ", + arg_nodes->size(), ", but got index ", + index_node.first); + } + arg_nodes->push_back(index_node.second); + } + return Status::OK(); +} + +// Fills in xla_args from the corresponding _Arg nodes in the graph. +Status CreateXlaArgs(const Graph& graph, + std::vector* xla_args) { + std::vector arg_nodes; + TF_RETURN_IF_ERROR(CollectArgNodes(graph, &arg_nodes)); + for (const Node* node : arg_nodes) { + XlaCompiler::Argument arg; + TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "T", &arg.type)); + TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "index", &arg.parameter)); + TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), kShapeAttr, &arg.shape)); + TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), kDebugNameAttr, &arg.name)); + xla_args->push_back(arg); + } + return Status::OK(); +} + +// Converts the TensorFlow graph into an XLA computation, by executing the +// graph symbolically, with each op building up the XLA HLO. +Status ConvertGraphToXla(xla::LocalClient* client, std::unique_ptr graph, + const FunctionLibraryDefinition* flib_def, + xla::Computation* computation, bool* has_context_arg) { + // Create a device and context to convert the graph into an XLA computation. + XlaOpRegistry::RegisterJitKernels(); + // Populate the context with args from the graph. + for (Node* node : graph->nodes()) { + node->set_assigned_device_name(DEVICE_CPU_XLA_JIT); + } + std::vector xla_args; + TF_RETURN_IF_ERROR(CreateXlaArgs(*graph, &xla_args)); + + // Compile the graph into an XLA computation. + XlaCompiler::Options compiler_options; + compiler_options.client = client; + compiler_options.device_type = DeviceType(DEVICE_CPU_XLA_JIT); + compiler_options.allow_cpu_custom_calls = true; + XlaCompiler compiler(compiler_options); + + std::unique_ptr flib_run(NewFunctionLibraryRuntime( + compiler.device_mgr(), Env::Default(), compiler.device(), + graph->versions().producer(), flib_def, OptimizerOptions())); + XlaCompiler::CompilationResult result; + TF_RETURN_IF_ERROR(compiler.CompileGraph("tfcompile", std::move(graph), + flib_run.get(), xla_args, + false /* use_tuple_arg */, &result)); + *has_context_arg = result.requires_runtime_context; + *computation = std::move(result.computation); + + int num_const_results = 0; + for (int i = 0; i < result.outputs.size(); ++i) { + // Ending up with const results (i.e. output args) is an error, since it + // means that one or more fetches that the user specified will be dropped + // from the generated function. It's most likely a configuration error, + // since the user shouldn't be asking for output args that end up as consts. + // + // TODO(toddw): Provide a way for the user to access const output args, + // e.g. perhaps hard-coded into the header, or somehow copied into the + // output buffers. 
+ if (result.outputs[i].is_constant) { + ++num_const_results; + LOG(ERROR) << "ConstRetVal index:" << i + << " value:" << result.outputs[i].constant_value.DebugString(); + } + } + if (num_const_results > 0) { + return errors::Unimplemented( + "Conversion from TensorFlow graph to XLA resulted in ", + num_const_results, + " constant results. The configuration of " + "the output args (i.e. fetch ids) is probably wrong."); + } + if (computation->IsNull()) { + return errors::Aborted( + "Conversion from TensorFlow graph to XLA resulted in an empty " + "computation."); + } + return Status::OK(); +} + +// Compiles the XLA computation into executable code. +Status CompileXla(xla::LocalClient* client, const xla::Computation& computation, + const xla::cpu::CpuAotCompilationOptions& aot_opts, + CompileResult* compile_result) { + // Retrieves arg and result layouts from the computation. + // TODO(toddw): Should we let the user choose the major/minor ordering? + xla::StatusOr> pshape_or = + client->GetComputationShape(computation); + if (!pshape_or.ok()) { + return errors::Unknown("Couldn't get XLA program shape: ", + pshape_or.status().error_message()); + } + compile_result->program_shape = *pshape_or.ValueOrDie(); + xla::ProgramShape* pshape = &compile_result->program_shape; + std::vector arg_layouts; + for (int i = 0; i < pshape->parameters_size(); ++i) { + arg_layouts.push_back(pshape->mutable_parameters(i)); + } + xla::StatusOr> aot_or = + client->CompileAheadOfTime(computation, arg_layouts, pshape->result(), + aot_opts); + if (!aot_or.ok()) { + return errors::Unknown("XLA compilation failed: ", + aot_or.status().error_message()); + } + compile_result->aot = + xla::unique_ptr_static_cast( + aot_or.ConsumeValueOrDie()); + compile_result->entry_point = aot_opts.entry_point_name(); + compile_result->pointer_size = + xla::LocalClient::PointerSizeForTriple(aot_opts.triple()); + return Status::OK(); +} + +} // namespace + +Status InitGraph(const GraphDef& graph_def, const Config& config, + const MainFlags& flags, const FunctionLibraryDefinition* flib, + std::unique_ptr* graph) { + TF_RETURN_IF_ERROR(ValidateConfig(config)); + std::unique_ptr g(new Graph(flib)); + GraphDef copy_def(graph_def); + TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(©_def, *g->op_registry(), + 0 /*node_offset*/)); + TF_RETURN_IF_ERROR( + ConvertGraphDefToGraph(GraphConstructorOptions(), copy_def, g.get())); + TF_RETURN_IF_ERROR(RewriteAndPruneGraph(g.get(), config, flags)); + *graph = std::move(g); + return Status::OK(); +} + +Status CompileGraph(std::unique_ptr graph, const MainFlags& flags, + const FunctionLibraryDefinition* flib, + CompileResult* compile_result) { + // Converts the graph into an XLA computation, and compiles the + // computation. + // TODO(toddw): Should we let the user pick the XLA cpu vs. gpu client? 
+ namespace gpu = perftools::gputools; + gpu::Platform* cpu_platform = + gpu::MultiPlatformManager::PlatformWithName("Host").ValueOrDie(); + xla::LocalClient* client = + xla::ClientLibrary::GetOrCreateLocalClient(cpu_platform).ValueOrDie(); + xla::Computation computation; + TF_RETURN_IF_ERROR(ConvertGraphToXla(client, std::move(graph), flib, + &computation, + &compile_result->has_context_arg)); + if (!flags.debug_dir.empty()) { + TF_ASSIGN_OR_RETURN(std::unique_ptr module, + computation.Snapshot()); + string file = io::JoinPath(flags.debug_dir, "tfcompile_xla_module.pb"); + TF_RETURN_IF_ERROR(WriteBinaryProto(Env::Default(), file, *module)); + } + xla::cpu::CpuAotCompilationOptions aot_opts( + flags.target_triple, flags.target_cpu, flags.target_features, + flags.entry_point, + xla::cpu::CpuAotCompilationOptions::RelocationModel::BigPic); + return CompileXla(client, computation, aot_opts, compile_result); +} + +} // namespace tfcompile +} // namespace tensorflow diff --git a/tensorflow/compiler/aot/compile.h b/tensorflow/compiler/aot/compile.h new file mode 100644 index 0000000000..8e9c64820b --- /dev/null +++ b/tensorflow/compiler/aot/compile.h @@ -0,0 +1,92 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_AOT_COMPILE_H_ +#define TENSORFLOW_COMPILER_AOT_COMPILE_H_ + +#include +#include +#include + +#include "tensorflow/compiler/aot/flags.h" +#include "tensorflow/compiler/aot/tfcompile.pb.h" +#include "tensorflow/compiler/xla/service/compiler.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { +namespace tfcompile { + +// Constants for op types and attribute names. +extern const char* const kArgOp; +extern const char* const kRetvalOp; +extern const char* const kFeedIdAttr; +extern const char* const kFetchIdAttr; +extern const char* const kShapeAttr; +extern const char* const kDebugNameAttr; + +// InitGraph creates a graph based on the graph_def, that may then be compiled +// by CompileGraph. +// +// The graph is rewritten with _Arg and _Retval nodes, representing the inputs +// and outputs of the function that will be compiled. Each feed id causes a new +// _Arg node to be created, where we first collect all existing edges pointing +// from the named node's output index, and then rewrite them to point from that +// _Arg node instead. Each fetch id causes a new _Retval node to be created, +// with a new edge pointing from the named node's output index to that _Retval +// node. All _Retval nodes also point to a special CompileExpressions node, +// used internally to finish the compilation. +// +// The rewritten graph is then pruned to only contain the portion necessary to +// compute the outputs. 
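In practice the two entry points declared in this header are used back to back: InitGraph produces the rewritten graph and CompileGraph turns it into an object file plus metadata. A rough sketch of that sequence, assuming the GraphDef and Config protos have already been read from the files named by the flags (the real driver is tfcompile_main.cc in this patch, and whether it builds the function library exactly this way is not shown in this hunk):

  MainFlags flags;     // populated from the command line
  GraphDef graph_def;  // read from flags.graph
  Config config;       // read from flags.config
  FunctionLibraryDefinition flib(OpRegistry::Global(), graph_def.library());

  std::unique_ptr<Graph> graph;
  TF_RETURN_IF_ERROR(InitGraph(graph_def, config, flags, &flib, &graph));

  CompileResult compile_result;
  TF_RETURN_IF_ERROR(
      CompileGraph(std::move(graph), flags, &flib, &compile_result));
  // compile_result.aot now holds the object file bytes; GenerateHeader in
  // codegen.h renders the matching C++ header around it.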
If dump_graphs is true, graph rewrites will be dumped +// for debugging. +Status InitGraph(const GraphDef& graph_def, const Config& config, + const MainFlags& flags, const FunctionLibraryDefinition* flib, + std::unique_ptr* graph); + +// CompileResult describes the output of CompileGraph, where the object file +// data and meta-information is available in aot. +struct CompileResult { + // Contains object file and meta-info. + std::unique_ptr aot; + xla::ProgramShape program_shape; // Static shape of args and results. + bool has_context_arg = false; // Is last arg XlaLocalRuntimeContext? + string entry_point; // Name of generated function. + int pointer_size = 0; // Size of a pointer in bytes. +}; + +// CompileGraph compiles the graph into an object file containing a function +// that performs the graph operations. +// +// The graph must have _Arg and _Retval nodes representing the function inputs +// and outputs. Every _Arg node must have a shape attribute (key=kShapeAttr, +// value=TensorShape) representing the static shape of that input, and every +// _Retval node must point to a CompileExpressions node. +// +// Typically InitGraph is called to perform this initialization, followed by +// full specification of the shape attributes. +// +// The XLA compilation options are specified in the flags. +Status CompileGraph(std::unique_ptr graph, const MainFlags& flags, + const FunctionLibraryDefinition* flib, + CompileResult* result); + +} // namespace tfcompile +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_AOT_COMPILE_H_ diff --git a/tensorflow/compiler/aot/flags.cc b/tensorflow/compiler/aot/flags.cc new file mode 100644 index 0000000000..4e3998b682 --- /dev/null +++ b/tensorflow/compiler/aot/flags.cc @@ -0,0 +1,72 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/aot/flags.h" + +namespace tensorflow { +namespace tfcompile { + +void AppendMainFlags(std::vector* flag_list, MainFlags* flags) { + const std::vector tmp = { + {"graph", &flags->graph, + "Input GraphDef file. If the file ends in '.pbtxt' it is expected to " + "be in the human-readable proto text format, otherwise it is expected " + "to be in the proto binary format."}, + {"config", &flags->config, + "Input file containing Config proto. If the file ends in '.pbtxt' it " + "is expected to be in the human-readable proto text format, otherwise " + "it is expected to be in the proto binary format."}, + {"dump_fetch_nodes", &flags->dump_fetch_nodes, + "If set, only flags related to fetches are processed, and the resulting " + "fetch nodes will be dumped to stdout in a comma-separated list. " + "Typically used to format arguments for other tools, e.g. 
" + "freeze_graph."}, + {"debug_dir", &flags->debug_dir, + "Specifies a directory to dump debugging information, including " + "rewritten graphs and the XLA HLO module."}, + // Flags controlling the XLA ahead-of-time compilation, that correspond to + // the fields of xla::cpu::CpuAotCompilationOptions. + // + // TODO(toddw): The following flags also need to be supported: + // --xla_cpu_llvm_opt_level + // --xla_cpu_llvm_cl_opts + {"target_triple", &flags->target_triple, + "Target platform, similar to the clang -target flag. The general " + "format is ---. " + "http://clang.llvm.org/docs/CrossCompilation.html#target-triple."}, + {"target_cpu", &flags->target_cpu, + "Target cpu, similar to the clang -mcpu flag. " + "http://clang.llvm.org/docs/CrossCompilation.html#cpu-fpu-abi"}, + {"target_features", &flags->target_features, + "Target features, e.g. +avx2, +neon, etc."}, + {"entry_point", &flags->entry_point, + "Name of the generated function. If multiple generated object files " + "will be linked into the same binary, each will need a unique entry " + "point."}, + {"cpp_class", &flags->cpp_class, + "Name of the generated C++ class, wrapping the generated function. The " + "syntax of this flag is [[::],...]. " + "This mirrors the C++ syntax for referring to a class, where multiple " + "namespaces may precede the class name, separated by double-colons. " + "The class will be generated in the given namespace(s), or if no " + "namespaces are given, within the global namespace."}, + {"out_object", &flags->out_object, "Output object file name."}, + {"out_header", &flags->out_header, "Output header file name."}, + }; + flag_list->insert(flag_list->end(), tmp.begin(), tmp.end()); +} + +} // namespace tfcompile +} // namespace tensorflow diff --git a/tensorflow/compiler/aot/flags.h b/tensorflow/compiler/aot/flags.h new file mode 100644 index 0000000000..e11a0173fa --- /dev/null +++ b/tensorflow/compiler/aot/flags.h @@ -0,0 +1,48 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_AOT_FLAGS_H_ +#define TENSORFLOW_COMPILER_AOT_FLAGS_H_ + +#include +#include + +#include "tensorflow/core/util/command_line_flags.h" + +namespace tensorflow { +namespace tfcompile { + +// Flags for the tfcompile binary. See *.cc file for descriptions. +struct MainFlags { + string graph; + string config; + bool dump_fetch_nodes = false; + string debug_dir; + string target_triple; + string target_cpu; + string target_features; + string entry_point; + string cpp_class; + string out_object; + string out_header; +}; + +// Appends to flag_list a tensorflow::Flag for each field in MainFlags. 
+void AppendMainFlags(std::vector* flag_list, MainFlags* flags); + +} // namespace tfcompile +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_AOT_FLAGS_H_ diff --git a/tensorflow/compiler/aot/runtime.cc b/tensorflow/compiler/aot/runtime.cc new file mode 100644 index 0000000000..208de5498d --- /dev/null +++ b/tensorflow/compiler/aot/runtime.cc @@ -0,0 +1,98 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/aot/runtime.h" + +#include + +#include "tensorflow/core/platform/dynamic_annotations.h" + +namespace tensorflow { +namespace tfcompile { +namespace runtime { + +namespace { + +// Inline memory allocation routines here, because depending on '//base' brings +// in libraries which use c++ streams, which adds considerable code size on +// android. +inline void* aligned_malloc(size_t size, int minimum_alignment) { +#if defined(__ANDROID__) || defined(OS_ANDROID) || defined(OS_CYGWIN) + return memalign(minimum_alignment, size); +#else // !__ANDROID__ && !OS_ANDROID && !OS_CYGWIN + void* ptr = nullptr; + // posix_memalign requires that the requested alignment be at least + // sizeof(void*). In this case, fall back on malloc which should return memory + // aligned to at least the size of a pointer. + const int required_alignment = sizeof(void*); + if (minimum_alignment < required_alignment) return malloc(size); + if (posix_memalign(&ptr, minimum_alignment, size) != 0) + return nullptr; + else + return ptr; +#endif +} + +inline void aligned_free(void* aligned_memory) { free(aligned_memory); } + +size_t align_to(size_t n, size_t align) { + return (((n - 1) / align) + 1) * align; +} + +} // namespace + +size_t aligned_buffer_bytes(const intptr_t* sizes, size_t n) { + size_t total = 0; + for (size_t i = 0; i < n; ++i) { + if (sizes[i] != -1) { + total += align_to(sizes[i], kAlign); + } + } + return total; +} + +void* MallocContiguousBuffers(const intptr_t* sizes, size_t n, void** bufs, + bool annotate_initialized) { + const size_t total = aligned_buffer_bytes(sizes, n); + void* contiguous = nullptr; + if (total > 0) { + contiguous = aligned_malloc(total, kAlign); + if (annotate_initialized) { + // Since the memory for temp buffers is written to by JITed code, msan has + // no way of knowing the memory was initialized, so explicitly mark it. 
+ TF_ANNOTATE_MEMORY_IS_INITIALIZED(contiguous, total); + } + } + uintptr_t pos = reinterpret_cast(contiguous); + for (size_t i = 0; i < n; ++i) { + if (sizes[i] == -1) { + bufs[i] = nullptr; + } else { + bufs[i] = reinterpret_cast(pos); + pos += align_to(sizes[i], kAlign); + } + } + return contiguous; +} + +void FreeContiguous(void* contiguous) { + if (contiguous != nullptr) { + aligned_free(contiguous); + } +} + +} // namespace runtime +} // namespace tfcompile +} // namespace tensorflow diff --git a/tensorflow/compiler/aot/runtime.h b/tensorflow/compiler/aot/runtime.h new file mode 100644 index 0000000000..d085864f00 --- /dev/null +++ b/tensorflow/compiler/aot/runtime.h @@ -0,0 +1,58 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file contains utilities to make it easier to invoke functions generated +// by tfcompile. Usage of these utilities is optional. + +#ifndef TENSORFLOW_COMPILER_AOT_RUNTIME_H_ +#define TENSORFLOW_COMPILER_AOT_RUNTIME_H_ + +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace tfcompile { +namespace runtime { + +// Align to 32-bytes, to mimic tensorflow::Allocator::kAllocatorAlignment. +static constexpr size_t kAlign = 32; + +// aligned_buffer_bytes returns the sum of each size in `sizes`, skipping -1 +// values. There are `n` entries in `sizes`. Each buffer is aligned to kAlign +// byte boundaries. +size_t aligned_buffer_bytes(const intptr_t* sizes, size_t n); + +// MallocContiguousBuffers allocates buffers for use by the entry point +// generated by tfcompile. `sizes` is an array of byte sizes for each buffer, +// where -1 causes the buffer pointer to be nullptr. There are `n` entries in +// `sizes`. If `annotate_initialized` is set, the allocated memory will be +// annotated as having been initialized - this is useful when allocating +// temporary buffers. +// +// A single contiguous block of memory is allocated, and portions of it are +// parceled out into `bufs`, which must have space for `n` entries. Returns the +// head of the allocated contiguous block, which should be passed to +// FreeContiguous when the buffers are no longer in use. +void* MallocContiguousBuffers(const intptr_t* sizes, size_t n, void** bufs, + bool annotate_initialized); + +// FreeContiguous frees the contiguous block of memory allocated by +// MallocContiguousBuffers. +void FreeContiguous(void* contiguous); + +} // namespace runtime +} // namespace tfcompile +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_AOT_RUNTIME_H_ diff --git a/tensorflow/compiler/aot/runtime_test.cc b/tensorflow/compiler/aot/runtime_test.cc new file mode 100644 index 0000000000..ac79c278c1 --- /dev/null +++ b/tensorflow/compiler/aot/runtime_test.cc @@ -0,0 +1,125 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
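The packing rule is easiest to see on the mixed-size case that the unit test below exercises: for sizes {1, -1, 32, -1, 64, 2, 3}, each non-negative size is rounded up to kAlign = 32, so the buffers land at offsets 0, 32, 64, 128 and 160 inside one contiguous 32 + 32 + 64 + 32 + 32 = 192 byte block, and the two -1 entries come back as nullptr. A small sketch of calling it directly, with values chosen to match that test:

  static constexpr intptr_t sizes[7] = {1, -1, 32, -1, 64, 2, 3};
  void* bufs[7];
  void* block = tensorflow::tfcompile::runtime::MallocContiguousBuffers(
      sizes, 7, bufs, /*annotate_initialized=*/false);
  // bufs[0] == block, bufs[2] is 32 bytes past it, ..., bufs[1] and bufs[3]
  // are nullptr because their sizes were -1.
  tensorflow::tfcompile::runtime::FreeContiguous(block);  // frees the whole block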
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/aot/runtime.h" + +#include "tensorflow/compiler/tf2xla/xla_local_runtime_context.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace tfcompile { +namespace runtime { +namespace { + +TEST(Runtime, AlignmentValue) { + // We've chosen 32 byte alignment for the tfcompile runtime to mimic the + // regular tensorflow allocator, which was chosen to play nicely with Eigen. + // The tfcompile runtime also has a requirement that comes from the xla + // generated code, on the relation: buffer_size >= 16 ? 2 * sizeof(void*) : 8 + // So any value that we choose must abide by that constraint as well. + EXPECT_EQ(kAlign, Allocator::kAllocatorAlignment); +} + +TEST(Runtime, AlignedBufferBytes) { + EXPECT_EQ(aligned_buffer_bytes(nullptr, 0), 0); + + static constexpr intptr_t sizesA[1] = {-1}; + EXPECT_EQ(aligned_buffer_bytes(sizesA, 1), 0); + + static constexpr intptr_t sizesB[1] = {3}; + EXPECT_EQ(aligned_buffer_bytes(sizesB, 1), 32); + + static constexpr intptr_t sizesC[1] = {32}; + EXPECT_EQ(aligned_buffer_bytes(sizesC, 1), 32); + + static constexpr intptr_t sizesD[7] = {1, -1, 32, -1, 64, 2, 3}; + EXPECT_EQ(aligned_buffer_bytes(sizesD, 7), 192); +} + +void* add_ptr(void* base, uintptr_t delta) { + return reinterpret_cast(reinterpret_cast(base) + delta); +} + +// To test MallocContiguousBuffers and FreeContiguous, we just check for +// expected nullptrs, and write to each byte of allocated memory. We rely on +// the leak checker to tell us if there's an inconsistency between malloc and +// free. We also check the contiguous property. +TEST(Runtime, MallocFreeContiguousBuffers) { + // Test empty sizes. + void* base = MallocContiguousBuffers(nullptr, 0, nullptr, false); + EXPECT_EQ(base, nullptr); + FreeContiguous(base); + + // Test non-empty sizes with 0 sum. + static constexpr intptr_t sizesA[1] = {-1}; + void* bufA[1]; + base = MallocContiguousBuffers(sizesA, 1, bufA, false); + EXPECT_EQ(base, nullptr); + EXPECT_EQ(bufA[0], nullptr); + FreeContiguous(base); + + // Test non-empty sizes with non-0 sum. + static constexpr intptr_t sizesB[1] = {3}; + void* bufB[1]; + base = MallocContiguousBuffers(sizesB, 1, bufB, false); + EXPECT_NE(base, nullptr); + EXPECT_EQ(bufB[0], add_ptr(base, 0)); + char* bufB0_bytes = static_cast(bufB[0]); + bufB0_bytes[0] = 'A'; + bufB0_bytes[1] = 'B'; + bufB0_bytes[2] = 'C'; + FreeContiguous(base); + + // Test non-empty sizes with non-0 sum, and annotate_initialized. + static constexpr intptr_t sizesC[1] = {3}; + void* bufC[1]; + base = MallocContiguousBuffers(sizesC, 1, bufC, true); + EXPECT_NE(base, nullptr); + EXPECT_EQ(bufC[0], add_ptr(base, 0)); + char* bufC0_bytes = static_cast(bufC[0]); + bufC0_bytes[0] = 'A'; + bufC0_bytes[1] = 'B'; + bufC0_bytes[2] = 'C'; + FreeContiguous(base); + + // Test mixed sizes. 
+ static constexpr intptr_t sizesD[7] = {1, -1, 32, -1, 64, 2, 3}; + void* bufD[7]; + base = MallocContiguousBuffers(sizesD, 7, bufD, false); + EXPECT_NE(base, nullptr); + EXPECT_EQ(bufD[0], add_ptr(base, 0)); + EXPECT_EQ(bufD[1], nullptr); + EXPECT_EQ(bufD[2], add_ptr(base, 32)); + EXPECT_EQ(bufD[3], nullptr); + EXPECT_EQ(bufD[4], add_ptr(base, 64)); + EXPECT_EQ(bufD[5], add_ptr(base, 128)); + EXPECT_EQ(bufD[6], add_ptr(base, 160)); + for (int i = 0; i < 7; ++i) { + const intptr_t size = sizesD[i]; + if (size != -1) { + char* bufD_bytes = static_cast(bufD[i]); + for (size_t j = 0; j < size; ++j) { + bufD_bytes[j] = 'A' + j; + } + } + } + FreeContiguous(base); +} + +} // namespace +} // namespace runtime +} // namespace tfcompile +} // namespace tensorflow diff --git a/tensorflow/compiler/aot/test.cc b/tensorflow/compiler/aot/test.cc new file mode 100644 index 0000000000..47ef5f82cb --- /dev/null +++ b/tensorflow/compiler/aot/test.cc @@ -0,0 +1,94 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Generated by the tf_library build rule. DO NOT EDIT! +// +// This file contains a test and benchmark for the function generated by +// tfcompile. All tokens of the form `{{TFCOMPILE_*}}` must be rewritten to +// real values before this file can be compiled. +// +// TFCOMPILE_HEADER : Path to the header file generated by tfcompile. +// TFCOMPILE_CPP_CLASS : Name of the C++ class generated by tfcompile. +// TFCOMPILE_NAME : Name for tests and benchmarks. +// +// The tf_library bazel macro in tfcompile.bzl performs the token rewriting, and +// generates a cc_test rule for you. + +// These macros must be defined before eigen files are included. +#define EIGEN_USE_THREADS +#define EIGEN_USE_CUSTOM_THREAD_POOL + +// clang-format off +#include "{{TFCOMPILE_HEADER}}" // NOLINT(whitespace/braces) +// clang-format on + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/platform/cpu_info.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" + +// Macros that expand to tokens based on the entry point name. +// clang-format off +#define CPP_CLASS {{TFCOMPILE_CPP_CLASS}} // NOLINT(whitespace/braces) +#define TEST_NAME {{TFCOMPILE_NAME}}Test // NOLINT(whitespace/braces) +#define BM_NAME BM_{{TFCOMPILE_NAME}} // NOLINT(whitespace/braces) +// clang-format on + +namespace tensorflow { +namespace tfcompile { +namespace { + +void zero_buffers(void** bufs, const intptr_t* sizes, size_t n) { + for (int i = 0; i < n; ++i) { + if (sizes[i] != -1) { + memset(bufs[i], 0, sizes[i]); + } + } +} + +// Trivial test that runs the generated function to ensure it doesn't crash. 
+TEST(TEST_NAME, NoCrash) { + Eigen::ThreadPool pool(port::NumSchedulableCPUs()); + Eigen::ThreadPoolDevice device(&pool, pool.NumThreads()); + + CPP_CLASS computation; + computation.set_thread_pool(&device); + zero_buffers(computation.args(), CPP_CLASS::ArgSizes(), CPP_CLASS::kNumArgs); + + EXPECT_TRUE(computation.Run()); +} + +// Simple benchmark that repeatedly runs the generated function. +void BM_NAME(int iters) { + testing::StopTiming(); + + Eigen::ThreadPool pool(port::NumSchedulableCPUs()); + Eigen::ThreadPoolDevice device(&pool, pool.NumThreads()); + + CPP_CLASS computation; + computation.set_thread_pool(&device); + zero_buffers(computation.args(), CPP_CLASS::ArgSizes(), CPP_CLASS::kNumArgs); + + testing::StartTiming(); + while (--iters) { + computation.Run(); + } + testing::StopTiming(); +} +BENCHMARK(BM_NAME); + +} // namespace +} // namespace tfcompile +} // namespace tensorflow diff --git a/tensorflow/compiler/aot/test_graph_tfadd.config.pbtxt b/tensorflow/compiler/aot/test_graph_tfadd.config.pbtxt new file mode 100644 index 0000000000..5625c0ab03 --- /dev/null +++ b/tensorflow/compiler/aot/test_graph_tfadd.config.pbtxt @@ -0,0 +1,16 @@ +# Text form of tensorflow.tfcompile.Config proto. +feed { + id { node_name: "x_const" } + shape { + dim { size: 1 } + } +} +feed { + id { node_name: "y_const" } + shape { + dim { size: 1 } + } +} +fetch { + id { node_name: "x_y_sum" } +} diff --git a/tensorflow/compiler/aot/test_graph_tfadd.pbtxt b/tensorflow/compiler/aot/test_graph_tfadd.pbtxt new file mode 100644 index 0000000000..91c900e06d --- /dev/null +++ b/tensorflow/compiler/aot/test_graph_tfadd.pbtxt @@ -0,0 +1,63 @@ +node { + name : "x_const" + op : "Const" + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + attr { + key : "dtype" + value { + type : DT_INT32 + } + } +} +node { + name : "y_const" + op : "Const" + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 2 + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } +} +node { + name : "x_y_sum" + op : "Add" + input : "x_const" + input : "y_const" + attr { + key : "T" + value { + type: DT_INT32 + } + } +} +versions { + producer: 15 +} diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD new file mode 100644 index 0000000000..ecb071a416 --- /dev/null +++ b/tensorflow/compiler/aot/tests/BUILD @@ -0,0 +1,146 @@ +licenses(["notice"]) # Apache 2.0 + +package( + default_visibility = ["//visibility:private"], +) + +load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") + +test_suite( + name = "all_tests", + tags = ["manual"], + tests = [ + ":test_graph_tfadd_test", + ":test_graph_tfadd_with_ckpt_saver_test", + ":test_graph_tfadd_with_ckpt_test", + ":test_graph_tfgather_test", + ":test_graph_tfmatmul_test", + ":test_graph_tfmatmulandadd_test", + ":tfcompile_test", + ], +) + +py_binary( + name = "make_test_graphs", + testonly = 1, + srcs = ["make_test_graphs.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python", # TODO(b/34059704): remove when fixed + "//tensorflow/python:array_ops", + "//tensorflow/python:client", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:training", + "//tensorflow/python:variables", + ], +) + +genrule( + name = "gen_test_graphs", + testonly = 1, + outs = [ + 
"test_graph_tfadd.pb", + "test_graph_tfadd_with_ckpt.pb", + "test_graph_tfadd_with_ckpt.ckpt", + "test_graph_tfadd_with_ckpt_saver.pb", + "test_graph_tfadd_with_ckpt_saver.ckpt", + "test_graph_tfadd_with_ckpt_saver.saver", + "test_graph_tfgather.pb", + "test_graph_tfmatmul.pb", + "test_graph_tfmatmulandadd.pb", + ], + cmd = "$(location :make_test_graphs) --out_dir $(@D)", + tags = ["manual"], + tools = [":make_test_graphs"], +) + +tf_library( + name = "test_graph_tfadd", + testonly = 1, + config = "test_graph_tfadd.config.pbtxt", + cpp_class = "AddComp", + graph = "test_graph_tfadd.pb", + tags = ["manual"], +) + +tf_library( + name = "test_graph_tfadd_with_ckpt", + testonly = 1, + config = "test_graph_tfadd_with_ckpt.config.pbtxt", + cpp_class = "AddWithCkptComp", + freeze_checkpoint = "test_graph_tfadd_with_ckpt.ckpt", + graph = "test_graph_tfadd_with_ckpt.pb", + tags = ["manual"], +) + +tf_library( + name = "test_graph_tfadd_with_ckpt_saver", + testonly = 1, + config = "test_graph_tfadd_with_ckpt.config.pbtxt", + cpp_class = "AddWithCkptSaverComp", + freeze_checkpoint = "test_graph_tfadd_with_ckpt_saver.ckpt", + freeze_saver = "test_graph_tfadd_with_ckpt_saver.saver", + graph = "test_graph_tfadd_with_ckpt_saver.pb", + tags = ["manual"], +) + +tf_library( + name = "test_graph_tfgather", + testonly = 1, + config = "test_graph_tfgather.config.pbtxt", + cpp_class = "GatherComp", + graph = "test_graph_tfgather.pb", + tags = ["manual"], +) + +tf_library( + name = "test_graph_tfmatmul", + testonly = 1, + config = "test_graph_tfmatmul.config.pbtxt", + cpp_class = "foo::bar::MatMulComp", + graph = "test_graph_tfmatmul.pb", + tags = ["manual"], +) + +tf_library( + name = "test_graph_tfmatmulandadd", + testonly = 1, + config = "test_graph_tfmatmulandadd.config.pbtxt", + cpp_class = "MatMulAndAddComp", + graph = "test_graph_tfmatmulandadd.pb", + tags = ["manual"], +) + +cc_test( + name = "tfcompile_test", + srcs = ["tfcompile_test.cc"], + tags = ["manual"], + deps = [ + ":test_graph_tfadd", + ":test_graph_tfadd_with_ckpt", + ":test_graph_tfadd_with_ckpt_saver", + ":test_graph_tfgather", + ":test_graph_tfmatmul", + ":test_graph_tfmatmulandadd", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//third_party/eigen3", + ], +) + +# ----------------------------------------------------------------------------- + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/compiler/aot/tests/make_test_graphs.py b/tensorflow/compiler/aot/tests/make_test_graphs.py new file mode 100644 index 0000000000..261dfcbdf8 --- /dev/null +++ b/tensorflow/compiler/aot/tests/make_test_graphs.py @@ -0,0 +1,119 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Generate tensorflow graphs for testing tfcompile.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.core.protobuf import saver_pb2 +from tensorflow.python.client import session +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import app +from tensorflow.python.platform import flags as flags_lib +from tensorflow.python.training import saver as saver_lib + +flags = flags_lib +FLAGS = flags.FLAGS +flags.DEFINE_string('out_dir', '', + 'Output directory for graphs, checkpoints and savers.') + + +def tfadd(): + x = constant_op.constant([1], name='x_const') + y = constant_op.constant([2], name='y_const') + math_ops.add(x, y, name='x_y_sum') + + +def tfadd_with_ckpt(): + x = array_ops.placeholder(dtypes.int32, name='x_hold') + y = variables.Variable(constant_op.constant([0]), name='y_saved') + math_ops.add(x, y, name='x_y_sum') + + init_op = variables.initialize_all_variables() + saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V1) + with session.Session() as sess: + sess.run(init_op) + sess.run(y.assign(y + 42)) + # Without the checkpoint, the variable won't be set to 42. + ckpt = '%s/test_graph_tfadd_with_ckpt.ckpt' % FLAGS.out_dir + saver.save(sess, ckpt) + + +def tfadd_with_ckpt_saver(): + x = array_ops.placeholder(dtypes.int32, name='x_hold') + y = variables.Variable(constant_op.constant([0]), name='y_saved') + math_ops.add(x, y, name='x_y_sum') + + init_op = variables.initialize_all_variables() + saver = saver_lib.Saver(name='abcprefix', write_version=saver_pb2.SaverDef.V1) + with session.Session() as sess: + sess.run(init_op) + sess.run(y.assign(y + 42)) + # Without the checkpoint, the variable won't be set to 42. + ckpt_file = '%s/test_graph_tfadd_with_ckpt_saver.ckpt' % FLAGS.out_dir + saver.save(sess, ckpt_file) + # Without the SaverDef, the restore op won't be named correctly. + saver_file = '%s/test_graph_tfadd_with_ckpt_saver.saver' % FLAGS.out_dir + with open(saver_file, 'w') as f: + f.write(saver.as_saver_def().SerializeToString()) + + +def tfgather(): + params = array_ops.placeholder(dtypes.float32, name='params') + indices = array_ops.placeholder(dtypes.int32, name='indices') + array_ops.gather(params, indices, name='gather_output') + + +def tfmatmul(): + x = array_ops.placeholder(dtypes.float32, name='x_hold') + y = array_ops.placeholder(dtypes.float32, name='y_hold') + math_ops.matmul(x, y, name='x_y_prod') + + +def tfmatmulandadd(): + # This tests multiple outputs. 
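+  # The config for this graph lists both x_y_prod and x_y_sum as fetches, so
+  # the generated class has two results (result0 and result1 in
+  # tfcompile_test.cc) rather than one.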
+ x = array_ops.placeholder(dtypes.float32, name='x_hold') + y = array_ops.placeholder(dtypes.float32, name='y_hold') + math_ops.matmul(x, y, name='x_y_prod') + math_ops.add(x, y, name='x_y_sum') + + +def write_graph(build_graph): + """Build a graph using build_graph and write it out.""" + g = ops.Graph() + with g.as_default(): + build_graph() + filename = '%s/test_graph_%s.pb' % (FLAGS.out_dir, build_graph.__name__) + with open(filename, 'w') as f: + f.write(g.as_graph_def().SerializeToString()) + + +def main(_): + write_graph(tfadd) + write_graph(tfadd_with_ckpt) + write_graph(tfadd_with_ckpt_saver) + write_graph(tfgather) + write_graph(tfmatmul) + write_graph(tfmatmulandadd) + + +if __name__ == '__main__': + app.run() diff --git a/tensorflow/compiler/aot/tests/test_graph_tfadd.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tfadd.config.pbtxt new file mode 100644 index 0000000000..5625c0ab03 --- /dev/null +++ b/tensorflow/compiler/aot/tests/test_graph_tfadd.config.pbtxt @@ -0,0 +1,16 @@ +# Text form of tensorflow.tfcompile.Config proto. +feed { + id { node_name: "x_const" } + shape { + dim { size: 1 } + } +} +feed { + id { node_name: "y_const" } + shape { + dim { size: 1 } + } +} +fetch { + id { node_name: "x_y_sum" } +} diff --git a/tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.config.pbtxt new file mode 100644 index 0000000000..4d876a6e91 --- /dev/null +++ b/tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.config.pbtxt @@ -0,0 +1,10 @@ +# Text form of tensorflow.tfcompile.Config proto. +feed { + id { node_name: "x_hold" } + shape { + dim { size: 1 } + } +} +fetch { + id { node_name: "x_y_sum" } +} diff --git a/tensorflow/compiler/aot/tests/test_graph_tfgather.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tfgather.config.pbtxt new file mode 100644 index 0000000000..648ee31fdb --- /dev/null +++ b/tensorflow/compiler/aot/tests/test_graph_tfgather.config.pbtxt @@ -0,0 +1,16 @@ +# Text form of tensorflow.tfcompile.Config proto. +feed { + id { node_name: "params" } + shape { + dim { size: 4 } + } +} +feed { + id { node_name: "indices" } + shape { + dim { size: 2 } + } +} +fetch { + id { node_name: "gather_output" } +} diff --git a/tensorflow/compiler/aot/tests/test_graph_tfmatmul.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tfmatmul.config.pbtxt new file mode 100644 index 0000000000..a3ce2029c1 --- /dev/null +++ b/tensorflow/compiler/aot/tests/test_graph_tfmatmul.config.pbtxt @@ -0,0 +1,18 @@ +# Text form of tensorflow.tfcompile.Config proto. +feed { + id { node_name: "x_hold" } + shape { + dim { size: 2 } + dim { size: 3 } + } +} +feed { + id { node_name: "y_hold" } + shape { + dim { size: 3 } + dim { size: 2 } + } +} +fetch { + id { node_name: "x_y_prod" } +} diff --git a/tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.config.pbtxt new file mode 100644 index 0000000000..4a4a237a4f --- /dev/null +++ b/tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.config.pbtxt @@ -0,0 +1,25 @@ +# Text form of tensorflow.tfcompile.Config proto. 
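+# The optional `name` fields below give each feed and fetch a stable,
+# human-readable accessor in the generated header (arg_x()/arg_y() and
+# result_x_y_prod()/result_x_y_sum()) in addition to the positional
+# argN()/resultN() methods exercised in tfcompile_test.cc.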
+feed { + id { node_name: "x_hold" } + shape { + dim { size: 2 } + dim { size: 2 } + } + name: "x" +} +feed { + id { node_name: "y_hold" } + shape { + dim { size: 2 } + dim { size: 2 } + } + name: "y" +} +fetch { + id { node_name: "x_y_prod" } + name: "x_y_prod" +} +fetch { + id { node_name: "x_y_sum" } + name: "x_y_sum" +} diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc new file mode 100644 index 0000000000..f57d2859df --- /dev/null +++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc @@ -0,0 +1,381 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define EIGEN_USE_THREADS +#define EIGEN_USE_CUSTOM_THREAD_POOL + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/compiler/aot/tests/test_graph_tfadd.h" +#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h" +#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver.h" +#include "tensorflow/compiler/aot/tests/test_graph_tfgather.h" +#include "tensorflow/compiler/aot/tests/test_graph_tfmatmul.h" +#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace tfcompile { +namespace { + +TEST(TFCompileTest, Add) { + AddComp add; + EXPECT_EQ(add.arg0_data(), add.args()[0]); + EXPECT_EQ(add.arg1_data(), add.args()[1]); + + add.arg0() = 1; + add.arg1() = 2; + EXPECT_TRUE(add.Run()); + EXPECT_EQ(add.error_msg(), ""); + EXPECT_EQ(add.result0(), 3); + EXPECT_EQ(add.result0_data()[0], 3); + EXPECT_EQ(add.result0_data(), add.results()[0]); + + add.arg0_data()[0] = 123; + add.arg1_data()[0] = 456; + EXPECT_TRUE(add.Run()); + EXPECT_EQ(add.error_msg(), ""); + EXPECT_EQ(add.result0(), 579); + EXPECT_EQ(add.result0_data()[0], 579); + EXPECT_EQ(add.result0_data(), add.results()[0]); + + const AddComp& add_const = add; + EXPECT_EQ(add_const.error_msg(), ""); + EXPECT_EQ(add_const.arg0(), 123); + EXPECT_EQ(add_const.arg0_data()[0], 123); + EXPECT_EQ(add_const.arg0_data(), add.args()[0]); + EXPECT_EQ(add_const.arg1(), 456); + EXPECT_EQ(add_const.arg1_data()[0], 456); + EXPECT_EQ(add_const.arg1_data(), add.args()[1]); + EXPECT_EQ(add_const.result0(), 579); + EXPECT_EQ(add_const.result0_data()[0], 579); + EXPECT_EQ(add_const.result0_data(), add_const.results()[0]); +} + +// Run tests that use set_argN_data separately, to avoid accidentally re-using +// non-existent buffers. 
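+// In RESULTS_AND_TEMPS_ONLY mode the generated object allocates no argument
+// buffers of its own; the caller points every argument at caller-owned memory
+// via set_argN_data() before calling Run(), as the test below does.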
+TEST(TFCompileTest, Add_SetArg) { + AddComp add(AddComp::AllocMode::RESULTS_AND_TEMPS_ONLY); + + int32 arg_x = 10; + int32 arg_y = 32; + add.set_arg0_data(&arg_x); + add.set_arg1_data(&arg_y); + EXPECT_EQ(add.arg0_data(), add.args()[0]); + EXPECT_EQ(add.arg1_data(), add.args()[1]); + + EXPECT_TRUE(add.Run()); + EXPECT_EQ(add.error_msg(), ""); + EXPECT_EQ(add.result0(), 42); + EXPECT_EQ(add.result0_data()[0], 42); + EXPECT_EQ(add.result0_data(), add.results()[0]); +} + +TEST(TFCompileTest, AddWithCkpt) { + AddWithCkptComp add; + EXPECT_EQ(add.arg0_data(), add.args()[0]); + + add.arg0() = 1; + EXPECT_TRUE(add.Run()); + EXPECT_EQ(add.error_msg(), ""); + EXPECT_EQ(add.result0(), 43); + EXPECT_EQ(add.result0_data()[0], 43); + EXPECT_EQ(add.result0_data(), add.results()[0]); + + add.arg0_data()[0] = 111; + EXPECT_TRUE(add.Run()); + EXPECT_EQ(add.error_msg(), ""); + EXPECT_EQ(add.result0(), 153); + EXPECT_EQ(add.result0_data()[0], 153); + EXPECT_EQ(add.result0_data(), add.results()[0]); + + const AddWithCkptComp& add_const = add; + EXPECT_EQ(add_const.error_msg(), ""); + EXPECT_EQ(add_const.arg0(), 111); + EXPECT_EQ(add_const.arg0_data()[0], 111); + EXPECT_EQ(add_const.arg0_data(), add_const.args()[0]); + EXPECT_EQ(add_const.result0(), 153); + EXPECT_EQ(add_const.result0_data()[0], 153); + EXPECT_EQ(add_const.result0_data(), add_const.results()[0]); +} + +TEST(TFCompileTest, AddWithCkptSaver) { + AddWithCkptSaverComp add; + EXPECT_EQ(add.arg0_data(), add.args()[0]); + + add.arg0() = 1; + EXPECT_TRUE(add.Run()); + EXPECT_EQ(add.error_msg(), ""); + EXPECT_EQ(add.result0(), 43); + EXPECT_EQ(add.result0_data()[0], 43); + EXPECT_EQ(add.result0_data(), add.results()[0]); + + add.arg0_data()[0] = 111; + EXPECT_TRUE(add.Run()); + EXPECT_EQ(add.error_msg(), ""); + EXPECT_EQ(add.result0(), 153); + EXPECT_EQ(add.result0_data()[0], 153); + EXPECT_EQ(add.result0_data(), add.results()[0]); + + const AddWithCkptSaverComp& add_const = add; + EXPECT_EQ(add_const.error_msg(), ""); + EXPECT_EQ(add_const.arg0(), 111); + EXPECT_EQ(add_const.arg0_data()[0], 111); + EXPECT_EQ(add_const.arg0_data(), add_const.args()[0]); + EXPECT_EQ(add_const.result0(), 153); + EXPECT_EQ(add_const.result0_data()[0], 153); + EXPECT_EQ(add_const.result0_data(), add_const.results()[0]); +} + +TEST(TFCompileTest, Gather) { + GatherComp gather; + EXPECT_EQ(gather.arg0_data(), gather.args()[0]); + EXPECT_EQ(gather.arg1_data(), gather.args()[1]); + + // Successful gather. 
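+  // With params {1, 2, 3, 4} and indices {1, 3}, the gather picks out
+  // elements 1 and 3, so the expected result is {2, 4}. The second block
+  // below uses an out-of-range index (4) and expects Run() to fail instead.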
+ { + const float params[4] = {1, 2, 3, 4}; + std::copy(params + 0, params + 4, gather.arg0_data()); + const int32 indices[2] = {1, 3}; + std::copy(indices + 0, indices + 2, gather.arg1_data()); + EXPECT_TRUE(gather.Run()); + EXPECT_EQ(gather.error_msg(), ""); + const float results[2] = {2, 4}; + for (int i = 0; i < 2; ++i) { + EXPECT_EQ(gather.result0(i), results[i]); + EXPECT_EQ(gather.result0_data()[i], results[i]); + } + EXPECT_EQ(gather.result0_data(), gather.results()[0]); + + const GatherComp& gather_const = gather; + EXPECT_EQ(gather_const.error_msg(), ""); + for (int i = 0; i < 4; ++i) { + EXPECT_EQ(gather_const.arg0(i), params[i]); + EXPECT_EQ(gather_const.arg0_data()[i], params[i]); + } + EXPECT_EQ(gather_const.arg0_data(), gather_const.args()[0]); + for (int i = 0; i < 2; ++i) { + EXPECT_EQ(gather_const.arg1(i), indices[i]); + EXPECT_EQ(gather_const.arg1_data()[i], indices[i]); + } + EXPECT_EQ(gather_const.arg1_data(), gather_const.args()[1]); + for (int i = 0; i < 2; ++i) { + EXPECT_EQ(gather_const.result0(i), results[i]); + EXPECT_EQ(gather_const.result0_data()[i], results[i]); + } + EXPECT_EQ(gather_const.result0_data(), gather.results()[0]); + } + + // Bad indices returns an error. + { + const float params[4] = {1, 2, 3, 4}; + std::copy(params + 0, params + 4, gather.arg0_data()); + const int32 indices[2] = {1, 4}; + std::copy(indices + 0, indices + 2, gather.arg1_data()); + EXPECT_FALSE(gather.Run()); + EXPECT_EQ(gather.error_msg(), "Invalid index for gather"); + } +} + +TEST(TFCompileTest, MatMul2) { + Eigen::ThreadPool tp(2); + Eigen::ThreadPoolDevice device(&tp, tp.NumThreads()); + + foo::bar::MatMulComp matmul; + matmul.set_thread_pool(&device); + EXPECT_EQ(matmul.arg0_data(), matmul.args()[0]); + EXPECT_EQ(matmul.arg1_data(), matmul.args()[1]); + + // Test using the argN() methods. + { + matmul.arg0(0, 0) = 1; + matmul.arg0(0, 1) = 2; + matmul.arg0(0, 2) = 3; + matmul.arg0(1, 0) = 4; + matmul.arg0(1, 1) = 5; + matmul.arg0(1, 2) = 6; + + matmul.arg1(0, 0) = 7; + matmul.arg1(0, 1) = 8; + matmul.arg1(1, 0) = 9; + matmul.arg1(1, 1) = 10; + matmul.arg1(2, 0) = 11; + matmul.arg1(2, 1) = 12; + + EXPECT_TRUE(matmul.Run()); + EXPECT_EQ(matmul.error_msg(), ""); + const float results[4] = {58, 64, 139, 154}; + for (int i = 0; i < 4; ++i) { + EXPECT_EQ(matmul.result0(i / 2, i % 2), results[i]); + EXPECT_EQ(matmul.result0_data()[i], results[i]); + } + EXPECT_EQ(matmul.result0_data(), matmul.results()[0]); + } + + // Test using the argN_data() methods. 
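+  // The inputs here are the matrices from the previous block scaled by 10,
+  // so each entry of the product scales by 100: {58, 64, 139, 154} becomes
+  // {5800, 6400, 13900, 15400}.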
+ { + const float args[12] = {10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120}; + std::copy(args + 0, args + 6, matmul.arg0_data()); + std::copy(args + 6, args + 12, matmul.arg1_data()); + EXPECT_TRUE(matmul.Run()); + EXPECT_EQ(matmul.error_msg(), ""); + const float results[4] = {5800, 6400, 13900, 15400}; + for (int i = 0; i < 4; ++i) { + EXPECT_EQ(matmul.result0(i / 2, i % 2), results[i]); + EXPECT_EQ(matmul.result0_data()[i], results[i]); + } + EXPECT_EQ(matmul.result0_data(), matmul.results()[0]); + + const foo::bar::MatMulComp& matmul_const = matmul; + EXPECT_EQ(matmul_const.error_msg(), ""); + for (int i = 0; i < 6; ++i) { + EXPECT_EQ(matmul_const.arg0(i / 3, i % 3), args[i]); + EXPECT_EQ(matmul_const.arg0_data()[i], args[i]); + } + EXPECT_EQ(matmul_const.arg0_data(), matmul.args()[0]); + for (int i = 0; i < 6; ++i) { + EXPECT_EQ(matmul_const.arg1(i / 2, i % 2), args[i + 6]); + EXPECT_EQ(matmul_const.arg1_data()[i], args[i + 6]); + } + EXPECT_EQ(matmul_const.arg1_data(), matmul.args()[1]); + for (int i = 0; i < 4; ++i) { + EXPECT_EQ(matmul_const.result0(i / 2, i % 2), results[i]); + EXPECT_EQ(matmul_const.result0_data()[i], results[i]); + } + EXPECT_EQ(matmul_const.result0_data(), matmul.results()[0]); + } +} + +// Run tests that use set_argN_data separately, to avoid accidentally re-using +// non-existent buffers. +TEST(TFCompileTest, MatMul2_SetArg) { + Eigen::ThreadPool tp(2); + Eigen::ThreadPoolDevice device(&tp, tp.NumThreads()); + + foo::bar::MatMulComp matmul( + foo::bar::MatMulComp::AllocMode::RESULTS_AND_TEMPS_ONLY); + matmul.set_thread_pool(&device); + + // Test using the set_argN_data() methods. + float arg0[2][3] = {{1, 2, 3}, {4, 5, 6}}; + float arg1[3][2] = {{7, 8}, {9, 10}, {11, 12}}; + matmul.set_arg0_data(&arg0); + matmul.set_arg1_data(&arg1); + EXPECT_EQ(matmul.arg0_data(), matmul.args()[0]); + EXPECT_EQ(matmul.arg1_data(), matmul.args()[1]); + + EXPECT_TRUE(matmul.Run()); + EXPECT_EQ(matmul.error_msg(), ""); + const float results[4] = {58, 64, 139, 154}; + for (int i = 0; i < 4; ++i) { + EXPECT_EQ(matmul.result0(i / 2, i % 2), results[i]); + EXPECT_EQ(matmul.result0_data()[i], results[i]); + } + EXPECT_EQ(matmul.result0_data(), matmul.results()[0]); +} + +TEST(TFCompileTest, MatMulAndAdd1) { + Eigen::ThreadPool tp(1); + Eigen::ThreadPoolDevice device(&tp, tp.NumThreads()); + + MatMulAndAddComp muladd; + muladd.set_thread_pool(&device); + EXPECT_EQ(muladd.arg0_data(), muladd.args()[0]); + EXPECT_EQ(muladd.arg1_data(), muladd.args()[1]); + + // Test methods with positional args and results. 
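+  // With arg0 = [1 2; 3 4] and arg1 = [5 6; 7 8], the expected outputs are
+  //   matmul: [1*5+2*7  1*6+2*8; 3*5+4*7  3*6+4*8] = [19 22; 43 50]
+  //   add:    [1+5  2+6; 3+7  4+8]                 = [ 6  8; 10 12]
+  // which is what results0 and results1 below assert.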
+ { + const float args[8] = {1, 2, 3, 4, 5, 6, 7, 8}; + std::copy(args + 0, args + 4, muladd.arg0_data()); + std::copy(args + 4, args + 8, muladd.arg1_data()); + EXPECT_TRUE(muladd.Run()); + EXPECT_EQ(muladd.error_msg(), ""); + const float results0[4] = {19, 22, 43, 50}; + const float results1[4] = {6, 8, 10, 12}; + for (int i = 0; i < 4; ++i) { + EXPECT_EQ(muladd.result0(i / 2, i % 2), results0[i]); + EXPECT_EQ(muladd.result0_data()[i], results0[i]); + EXPECT_EQ(muladd.result1(i / 2, i % 2), results1[i]); + EXPECT_EQ(muladd.result1_data()[i], results1[i]); + } + EXPECT_EQ(muladd.result0_data(), muladd.results()[0]); + EXPECT_EQ(muladd.result1_data(), muladd.results()[1]); + + const MatMulAndAddComp& muladd_const = muladd; + EXPECT_EQ(muladd_const.error_msg(), ""); + for (int i = 0; i < 4; ++i) { + EXPECT_EQ(muladd_const.arg0(i / 2, i % 2), args[i]); + EXPECT_EQ(muladd_const.arg0_data()[i], args[i]); + } + EXPECT_EQ(muladd_const.arg0_data(), muladd.args()[0]); + for (int i = 0; i < 4; ++i) { + EXPECT_EQ(muladd_const.arg1(i / 2, i % 2), args[i + 4]); + EXPECT_EQ(muladd_const.arg1_data()[i], args[i + 4]); + } + EXPECT_EQ(muladd_const.arg1_data(), muladd.args()[1]); + for (int i = 0; i < 4; ++i) { + EXPECT_EQ(muladd_const.result0(i / 2, i % 2), results0[i]); + EXPECT_EQ(muladd_const.result0_data()[i], results0[i]); + EXPECT_EQ(muladd_const.result1(i / 2, i % 2), results1[i]); + EXPECT_EQ(muladd_const.result1_data()[i], results1[i]); + } + EXPECT_EQ(muladd_const.result0_data(), muladd.results()[0]); + EXPECT_EQ(muladd_const.result1_data(), muladd.results()[1]); + } + + // Test methods with named args and results. + { + const float args[8] = {10, 20, 30, 40, 50, 60, 70, 80}; + std::copy(args + 0, args + 4, muladd.arg_x_data()); + std::copy(args + 4, args + 8, muladd.arg_y_data()); + EXPECT_TRUE(muladd.Run()); + EXPECT_EQ(muladd.error_msg(), ""); + const float results0[4] = {1900, 2200, 4300, 5000}; + const float results1[4] = {60, 80, 100, 120}; + for (int i = 0; i < 4; ++i) { + EXPECT_EQ(muladd.result_x_y_prod(i / 2, i % 2), results0[i]); + EXPECT_EQ(muladd.result_x_y_prod_data()[i], results0[i]); + EXPECT_EQ(muladd.result_x_y_sum(i / 2, i % 2), results1[i]); + EXPECT_EQ(muladd.result_x_y_sum_data()[i], results1[i]); + } + EXPECT_EQ(muladd.result_x_y_prod_data(), muladd.results()[0]); + EXPECT_EQ(muladd.result_x_y_sum_data(), muladd.results()[1]); + + // Test const methods. 
+ const MatMulAndAddComp& muladd_const = muladd; + EXPECT_EQ(muladd_const.error_msg(), ""); + for (int i = 0; i < 4; ++i) { + EXPECT_EQ(muladd_const.arg_x(i / 2, i % 2), args[i]); + EXPECT_EQ(muladd_const.arg_x_data()[i], args[i]); + } + EXPECT_EQ(muladd_const.arg_x_data(), muladd.args()[0]); + for (int i = 0; i < 4; ++i) { + EXPECT_EQ(muladd_const.arg_y(i / 2, i % 2), args[i + 4]); + EXPECT_EQ(muladd_const.arg_y_data()[i], args[i + 4]); + } + EXPECT_EQ(muladd_const.arg_y_data(), muladd.args()[1]); + for (int i = 0; i < 4; ++i) { + EXPECT_EQ(muladd_const.result_x_y_prod(i / 2, i % 2), results0[i]); + EXPECT_EQ(muladd_const.result_x_y_prod_data()[i], results0[i]); + EXPECT_EQ(muladd_const.result_x_y_sum(i / 2, i % 2), results1[i]); + EXPECT_EQ(muladd_const.result_x_y_sum_data()[i], results1[i]); + } + EXPECT_EQ(muladd_const.result_x_y_prod_data(), muladd.results()[0]); + EXPECT_EQ(muladd_const.result_x_y_sum_data(), muladd.results()[1]); + } +} + +} // namespace +} // namespace tfcompile +} // namespace tensorflow diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl new file mode 100644 index 0000000000..6f2e0958fd --- /dev/null +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -0,0 +1,285 @@ +# -*- Python -*- + +"""Build macro that compiles a TensorFlow graph into a cc_library. + +To use from your BUILD file, add the following line to load the macro: + +load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") + +Then call the macro like this: + +tf_library( + name = "test_graph_tfmatmul", + config = "test_graph_tfmatmul.config.pbtxt", + cpp_class = "MatMulComp", + graph = ":test_graph_tfmatmul.pb", +) +""" + +load("//tensorflow:tensorflow.bzl", "if_android", "tf_copts") + +def tf_library(name, graph, config, + freeze_checkpoint=None, freeze_saver=None, + cpp_class=None, gen_test=True, gen_benchmark=True, + visibility=None, testonly=None, + tfcompile_flags=None, + tfcompile_tool="//tensorflow/compiler/aot:tfcompile", + deps=None, tags=None): + """Runs tfcompile to compile a TensorFlow graph into executable code. + + Args: + name: The name of the build rule. + graph: The TensorFlow GraphDef to compile. If the file ends in '.pbtxt' it + is expected to be in the human-readable proto text format, otherwise it is + expected to be in the proto binary format. + config: File containing tensorflow.tfcompile.Config proto. If the file ends + in '.pbtxt' it is expected to be in the human-readable proto text format, + otherwise it is expected to be in the proto binary format. + freeze_checkpoint: If provided, run freeze_graph with this checkpoint to + convert variables into constants. + freeze_saver: If provided, run freeze_graph with this saver, in SaverDef + binary form, to convert variables into constants. + cpp_class: The name of the generated C++ class, wrapping the generated + function. The syntax of this flag is + [[::],...]. This mirrors the C++ syntax + for referring to a class, where multiple namespaces may precede the class + name, separated by double-colons. The class will be generated in the + given namespace(s), or if no namespaces are given, within the global + namespace. + gen_test: If True, also generate a cc_test rule that builds a simple + test and benchmark. + gen_benchmark: If True, also generate a binary with a simple benchmark. + Unlike the output of gen_test, this benchmark can be run on android. + visibility: Bazel build visibility. + testonly: Bazel testonly attribute. 
+ tfcompile_flags: Extra flags to pass to tfcompile to control compilation. + tfcompile_tool: The tfcompile binary. A non-default can be passed to + use a tfcompile built with extra dependencies. + deps: a list of extra deps to include on the build rules for + the generated library. + tags: tags to apply to subsidiary build rules. + + The output header is called .h. + """ + if not cpp_class: + fail("cpp_class must be specified") + + tfcompile_graph = graph + if freeze_checkpoint or freeze_saver: + if not freeze_checkpoint: + fail("freeze_checkpoint must be specified when freeze_saver is specified") + + freeze_name = "freeze_" + name + freeze_file = freeze_name + ".pb" + + # First run tfcompile to generate the list of out_nodes. + out_nodes_file = "out_nodes_" + freeze_name + native.genrule( + name=("gen_" + out_nodes_file), + srcs=[config], + outs=[out_nodes_file], + cmd=("$(location " + tfcompile_tool + ")" + + " --config=$(location " + config + ")" + + " --dump_fetch_nodes > $@"), + tools=[tfcompile_tool], + # Run tfcompile on the build host, rather than forge, since it's + # typically way faster on the local machine. + local=1, + tags=tags, + ) + + # Now run freeze_graph to convert variables into constants. + freeze_args = (" --input_graph=$(location " + graph + ")" + + " --input_binary=" + str(not graph.endswith(".pbtxt")) + + " --input_checkpoint=$(location " + freeze_checkpoint + ")" + + " --output_graph=$(location " + freeze_file + ")" + + " --output_node_names=$$(<$(location " + out_nodes_file + + "))") + freeze_saver_srcs = [] + if freeze_saver: + freeze_args += " --input_saver=$(location " + freeze_saver + ")" + freeze_saver_srcs += [freeze_saver] + native.genrule( + name=freeze_name, + srcs=[ + graph, + freeze_checkpoint, + out_nodes_file, + ] + freeze_saver_srcs, + outs=[freeze_file], + cmd=("$(location //tensorflow/python/tools:freeze_graph)" + + freeze_args), + tools=["//tensorflow/python/tools:freeze_graph"], + tags=tags, + ) + tfcompile_graph = freeze_file + + # Rule that runs tfcompile to produce the header and object file. + header_file = name + ".h" + object_file = name + ".o" + ep = ("__" + PACKAGE_NAME + "__" + name).replace("/", "_") + native.genrule( + name=("gen_" + name), + srcs=[ + tfcompile_graph, + config, + ], + outs=[ + header_file, + object_file, + ], + cmd=("$(location " + tfcompile_tool + ")" + + " --graph=$(location " + tfcompile_graph + ")" + + " --config=$(location " + config + ")" + + " --entry_point=" + ep + + " --cpp_class=" + cpp_class + + " --target_triple=" + target_llvm_triple() + + " --out_header=$(@D)/" + header_file + + " --out_object=$(@D)/" + object_file + + " " + (tfcompile_flags or "")), + tools=[tfcompile_tool], + visibility=visibility, + testonly=testonly, + # Run tfcompile on the build host since it's typically faster on the local + # machine. + # + # Note that setting the local=1 attribute on a *test target* causes the + # test infrastructure to skip that test. However this is a genrule, not a + # test target, and runs with --genrule_strategy=forced_forge, meaning the + # local=1 attribute is ignored, and the genrule is still run. + # + # https://www.bazel.io/versions/master/docs/be/general.html#genrule + local=1, + tags=tags, + ) + + # The cc_library rule packaging up the header and object file, and needed + # kernel implementations. 
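+  # Client code depends on this target like any other cc_library: add it to
+  # deps, include the generated "<name>.h" header, and instantiate the class
+  # named by cpp_class (see test.cc and tfcompile_test.cc for examples).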
+ native.cc_library( + name=name, + srcs=[object_file], + hdrs=[header_file], + visibility=visibility, + testonly=testonly, + deps = [ + # TODO(cwhipkey): only depend on kernel code that the model actually needed. + "//tensorflow/compiler/tf2xla/kernels:gather_op_kernel_float_int32", + "//tensorflow/compiler/tf2xla/kernels:gather_op_kernel_float_int64", + "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d", + "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d", + "//tensorflow/compiler/aot:runtime", + "//tensorflow/compiler/tf2xla:xla_local_runtime_context", + "//tensorflow/compiler/xla/service/cpu:runtime_conv2d", + "//tensorflow/compiler/xla/service/cpu:runtime_matmul", + "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d", + "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul", + "//tensorflow/compiler/xla:executable_run_options", + "//third_party/eigen3", + "//tensorflow/core:framework_lite", + ] + (deps or []), + tags=tags, + ) + + # Variables used for gen_test and gen_benchmark. + no_ns_name = "" + cpp_class_split = cpp_class.rsplit("::", maxsplit=2) + if len(cpp_class_split) == 1: + no_ns_name = cpp_class_split[0] + else: + no_ns_name = cpp_class_split[1] + sed_replace = ( + "-e \"s|{{TFCOMPILE_HEADER}}|$(location " + header_file + ")|g\" " + + "-e \"s|{{TFCOMPILE_CPP_CLASS}}|" + cpp_class + "|g\" " + + "-e \"s|{{TFCOMPILE_NAME}}|" + no_ns_name + "|g\" ") + + if gen_test: + test_name = name + "_test" + test_file = test_name + ".cc" + # Rule to rewrite test.cc to produce the test_file. + native.genrule( + name=("gen_" + test_name), + testonly=1, + srcs=[ + "//tensorflow/compiler/aot:test.cc", + header_file, + ], + outs=[test_file], + cmd=("sed " + sed_replace + + " $(location //tensorflow/compiler/aot:test.cc) " + + "> $(OUTS)"), + tags=tags, + ) + + # The cc_test rule for the generated code. + native.cc_test( + name=test_name, + srcs=[test_file], + deps=[ + ":" + name, + "//tensorflow/compiler/tf2xla:xla_local_runtime_context", + "//tensorflow/compiler/aot:runtime", + "//tensorflow/compiler/aot:tf_library_test_main", + "//tensorflow/compiler/xla:executable_run_options", + "//third_party/eigen3", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], + tags=tags, + ) + + if gen_benchmark: + benchmark_name = name + "_benchmark" + benchmark_file = benchmark_name + ".cc" + benchmark_main = ("//tensorflow/compiler/aot:" + + "benchmark_main.template") + + # Rule to rewrite benchmark.cc to produce the benchmark_file. + native.genrule( + name=("gen_" + benchmark_name), + srcs=[ + benchmark_main, + header_file, + ], + testonly = testonly, + outs=[benchmark_file], + cmd=("sed " + sed_replace + + " $(location " + benchmark_main + ") " + + "> $(OUTS)"), + tags=tags, + ) + + # The cc_benchmark rule for the generated code. 
+ # + # Note: to get smaller size on android for comparison, compile with: + # --copt=-fvisibility=hidden + # --copt=-D_LIBCPP_TYPE_VIS=_LIBCPP_HIDDEN + # --copt=-D_LIBCPP_EXCEPTION_ABI=_LIBCPP_HIDDEN + native.cc_binary( + name=benchmark_name, + srcs=[benchmark_file], + testonly = testonly, + copts = tf_copts(), + linkopts = if_android(["-pie", "-s"]), + deps=[ + ":" + name, + "//tensorflow/compiler/tf2xla:xla_local_runtime_context", + "//tensorflow/compiler/aot:benchmark", + "//tensorflow/compiler/aot:runtime", + "//tensorflow/compiler/xla:executable_run_options", + "//third_party/eigen3", + ] + if_android([ + "//tensorflow/compiler/aot:benchmark_extra_android", + ]), + tags=tags, + ) + + +def target_llvm_triple(): + """Returns the target LLVM triple to be used for compiling the target.""" + # TODO(toddw): Add target_triple for other targets. For details see: + # http://llvm.org/docs/doxygen/html/Triple_8h_source.html + return select({ + "//tensorflow:android_arm": "armv7-none-android", + "//tensorflow:android_arm64": "aarch64-none-android", + "//conditions:default": "x86_64-pc-linux", + }) diff --git a/tensorflow/compiler/aot/tfcompile.proto b/tensorflow/compiler/aot/tfcompile.proto new file mode 100644 index 0000000000..be3f504350 --- /dev/null +++ b/tensorflow/compiler/aot/tfcompile.proto @@ -0,0 +1,43 @@ +syntax = "proto3"; + +package tensorflow.tfcompile; +option cc_enable_arenas = true; +option java_outer_classname = "CompileProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.tfcompile"; + +import "tensorflow/core/framework/tensor_shape.proto"; + +// TensorId identifies a tensor in a TensorFlow graph, by specifying the output +// index of a particular node in the graph. If the output of the named node +// feeds into other node(s), this corresponds to one or more edges. Otherwise +// it doesn't correspond to any existing edges at all, e.g. for output nodes. +message TensorId { + string node_name = 1; + int64 output_index = 2; +}; + +// Feed represents a single feed tensor in the graph, which corresponds to an +// input argument for the generated function. +message Feed { + TensorId id = 1; + TensorShapeProto shape = 2; + string name = 3; // Optional name for generated code. +}; + +// Fetch represents a single fetch tensor in the graph, which corresponds to an +// output argument for the generated function. +message Fetch { + TensorId id = 1; + string name = 2; // Optional name for generated code. +}; + +// Config represents configuration information for tfcompile. +message Config { + // Each feed is a positional input argument for the generated function. The + // order of each entry matches the order of each input argument. + repeated Feed feed = 1; + // Each fetch is a positional output argument for the generated function. The + // order of each entry matches the order of each output argument. + repeated Fetch fetch = 2; +}; diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc new file mode 100644 index 0000000000..85ef9560bb --- /dev/null +++ b/tensorflow/compiler/aot/tfcompile_main.cc @@ -0,0 +1,142 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <iostream>
+#include <memory>
+#include <set>
+#include <utility>
+
+#include "tensorflow/compiler/aot/codegen.h"
+#include "tensorflow/compiler/aot/compile.h"
+#include "tensorflow/compiler/aot/flags.h"
+#include "tensorflow/compiler/aot/tfcompile.pb.h"
+#include "tensorflow/compiler/aot/tfcompile_util.h"
+#include "tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h"
+#include "tensorflow/compiler/xla/service/compiler.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/tensor_id.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace tensorflow {
+namespace tfcompile {
+
+const char kUsageHeader[] =
+    "tfcompile performs ahead-of-time compilation of a TensorFlow graph,\n"
+    "resulting in an object file compiled for your target architecture, and a\n"
+    "header file that gives access to the functionality in the object file.\n"
+    "A typical invocation looks like this:\n"
+    "\n"
+    "   $ tfcompile --graph=mygraph.pb --config=myfile.pbtxt\n"
+    "\n";
+
+Status ReadProtoFile(const string& kind, const string& fname,
+                     protobuf::Message* proto) {
+  if (StringPiece(fname).ends_with(".pbtxt")) {
+    return ReadTextProto(Env::Default(), fname, proto);
+  } else {
+    return ReadBinaryProto(Env::Default(), fname, proto);
+  }
+}
+
+void ParseTensorId(const string& name, TensorId* id) {
+  const std::pair<StringPiece, int> name_index = ParseTensorName(name);
+  id->set_node_name(name_index.first.ToString());
+  id->set_output_index(name_index.second);
+}
+
+Status Main(const MainFlags& flags) {
+  // Process config.
+  Config config;
+  TF_RETURN_IF_ERROR(ReadProtoFile("config", flags.config, &config));
+  TF_RETURN_IF_ERROR(ValidateConfig(config));
+  if (flags.dump_fetch_nodes) {
+    std::set<string> nodes;
+    for (const Fetch& fetch : config.fetch()) {
+      nodes.insert(fetch.id().node_name());
+    }
+    std::cout << str_util::Join(nodes, ",");
+    return Status::OK();
+  }
+
+  // Read and initialize the graph.
+  GraphDef graph_def;
+  TF_RETURN_IF_ERROR(ReadProtoFile("graph", flags.graph, &graph_def));
+  std::unique_ptr<Graph> graph;
+  FunctionLibraryDefinition flib(OpRegistry::Global(), graph_def.library());
+  TF_RETURN_IF_ERROR(InitGraph(graph_def, config, flags, &flib, &graph));
+
+  CompileResult compile_result;
+  TF_RETURN_IF_ERROR(
+      CompileGraph(std::move(graph), flags, &flib, &compile_result));
+
+  // Write output files.
+  Env* env = Env::Default();
+  const std::vector<char>& obj = compile_result.aot->object_file_data();
+  TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_object,
+                                       StringPiece(obj.data(), obj.size())));
+  HeaderOpts header_opts;
+  TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &header_opts.class_name,
+                                   &header_opts.namespaces));
+  string header;
+  TF_RETURN_IF_ERROR(
+      GenerateHeader(header_opts, config, compile_result, &header));
+  TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_header, header));
+  return Status::OK();
+}
+
+}  // end namespace tfcompile
+}  // end namespace tensorflow
+
+int main(int argc, char** argv) {
+  tensorflow::tfcompile::MainFlags flags;
+  flags.target_triple = "x86_64-pc-linux";
+  flags.out_object = "out.o";
+  flags.out_header = "out.h";
+
+  std::vector<tensorflow::Flag> flag_list;
+  AppendMainFlags(&flag_list, &flags);
+  xla::legacy_flags::AppendCompilerFunctorFlags(&flag_list);
+  xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+  xla::legacy_flags::AppendCpuRuntimeFlags(&flag_list);
+
+  tensorflow::string usage = tensorflow::tfcompile::kUsageHeader;
+  usage += tensorflow::Flags::Usage(argv[0], flag_list);
+  bool parsed_flags_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
+  QCHECK(parsed_flags_ok) << "\n" << usage;
+
+  tensorflow::port::InitMain(usage.c_str(), &argc, &argv);
+  QCHECK(argc == 1 && !flags.config.empty() &&
+         (flags.dump_fetch_nodes ||
+          (!flags.graph.empty() && !flags.entry_point.empty())))
+      << "\n"
+      << usage;
+
+  TF_QCHECK_OK(tensorflow::tfcompile::Main(flags));
+  return 0;
+}
diff --git a/tensorflow/compiler/aot/tfcompile_util.cc b/tensorflow/compiler/aot/tfcompile_util.cc
new file mode 100644
index 0000000000..fd073a2e26
--- /dev/null
+++ b/tensorflow/compiler/aot/tfcompile_util.cc
@@ -0,0 +1,119 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/aot/tfcompile_util.h"
+
+#include <set>
+
+#include "tensorflow/compiler/aot/tfcompile.pb.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+
+namespace tensorflow {
+namespace tfcompile {
+
+namespace {
+
+bool IsAlpha(char c) {
+  return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z');
+}
+
+bool IsAlphaNum(char c) { return IsAlpha(c) || (c >= '0' && c <= '9'); }
+
+Status ValidateTensorId(const TensorId& id) {
+  if (id.node_name().empty()) {
+    return errors::InvalidArgument("TensorId node_name must be non-empty");
+  }
+  if (id.output_index() < 0) {
+    return errors::InvalidArgument("TensorId output_index must be positive");
+  }
+  return Status::OK();
+}
+
+Status ValidateFeedFetchName(const string& kind, const string& name,
+                             std::set<string>* names) {
+  if (!name.empty()) {
+    TF_RETURN_IF_ERROR(ValidateCppIdent(name, kind + " name"));
+    if (!names->insert(name).second) {
+      return errors::InvalidArgument("duplicate ", kind, " name: ", name);
+    }
+  }
+  return Status::OK();
+}
+
+Status CheckFeedFetchNameConflicts(const string& kind,
+                                   const std::set<string>& names) {
+  // We don't allow the feeds or fetches to contain both "foo" and "foo_data",
+  // since that will cause a collision in codegen symbols.
+  for (const string& name : names) {
+    const string name_data(name + "_data");
+    if (names.find(name_data) != names.end()) {
+      return errors::InvalidArgument("conflicting ", kind, " name: ", name,
+                                     " and ", name_data);
+    }
+  }
+  return Status::OK();
+}
+
+}  // namespace
+
+Status ValidateCppIdent(StringPiece ident, StringPiece msg) {
+  if (ident.empty()) {
+    return errors::InvalidArgument("empty identifier: ", msg);
+  }
+  // Require that the identifier starts with a nondigit, and is composed of
+  // nondigits and digits, as specified in section [2.11 Identifiers] of the
+  // C++11 Standard. Note that nondigit is defined as [_a-zA-Z] and digit is
+  // defined as [0-9].
+  //
+  // Technically the standard also allows for `universal-character-name`, with a
+  // table of allowed unicode ranges, as well as `other implementation-defined
+  // characters`. We disallow those here to give better error messages, at the
+  // expense of being more restrictive than the standard.
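+  //
+  // For example, "abc" and "_abc123" pass the checks below, while "" (empty),
+  // "0" (leading digit) and "a." (illegal character) are rejected; see
+  // tfcompile_util_test.cc for the full set of cases.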
+ if (ident[0] != '_' && !IsAlpha(ident[0])) { + return errors::InvalidArgument("illegal leading char: ", msg); + } + for (size_t pos = 1; pos < ident.size(); ++pos) { + if (ident[pos] != '_' && !IsAlphaNum(ident[pos])) { + return errors::InvalidArgument("illegal char: ", msg); + } + } + return Status::OK(); +} + +Status ValidateConfig(const Config& config) { + std::set names; + for (const Feed& feed : config.feed()) { + TF_RETURN_IF_ERROR(ValidateTensorId(feed.id())); + TF_RETURN_IF_ERROR(TensorShape::IsValidShape(feed.shape())); + TF_RETURN_IF_ERROR(ValidateFeedFetchName("feed", feed.name(), &names)); + } + TF_RETURN_IF_ERROR(CheckFeedFetchNameConflicts("feed", names)); + names.clear(); + for (const Fetch& fetch : config.fetch()) { + TF_RETURN_IF_ERROR(ValidateTensorId(fetch.id())); + TF_RETURN_IF_ERROR(ValidateFeedFetchName("fetch", fetch.name(), &names)); + } + TF_RETURN_IF_ERROR(CheckFeedFetchNameConflicts("fetch", names)); + if (config.feed().empty() || config.fetch().empty()) { + return errors::InvalidArgument("feeds and fetches must be specified"); + } + return Status::OK(); +} + +} // namespace tfcompile +} // namespace tensorflow diff --git a/tensorflow/compiler/aot/tfcompile_util.h b/tensorflow/compiler/aot/tfcompile_util.h new file mode 100644 index 0000000000..651d75d0d0 --- /dev/null +++ b/tensorflow/compiler/aot/tfcompile_util.h @@ -0,0 +1,36 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_AOT_TFCOMPILE_UTIL_H_ +#define TENSORFLOW_COMPILER_AOT_TFCOMPILE_UTIL_H_ + +#include "tensorflow/compiler/aot/tfcompile.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" + +namespace tensorflow { +namespace tfcompile { + +// ValidateCppIdent returns OK iff ident is a valid C++ identifier. The msg is +// appended to error messages. +Status ValidateCppIdent(StringPiece ident, StringPiece msg); + +// ValidateConfig returns OK iff config is valid. +Status ValidateConfig(const Config& config); + +} // namespace tfcompile +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_AOT_TFCOMPILE_UTIL_H_ diff --git a/tensorflow/compiler/aot/tfcompile_util_test.cc b/tensorflow/compiler/aot/tfcompile_util_test.cc new file mode 100644 index 0000000000..108ab1eab7 --- /dev/null +++ b/tensorflow/compiler/aot/tfcompile_util_test.cc @@ -0,0 +1,185 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/aot/tfcompile_util.h" + +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace tfcompile { +namespace { + +void ExpectErrorContains(Status status, StringPiece str) { + EXPECT_NE(Status::OK(), status); + EXPECT_TRUE(StringPiece(status.error_message()).contains(str)) + << "expected error: " << status.error_message() << " to contain: " << str; +} + +TEST(ValidateCppIdent, Simple) { + TF_EXPECT_OK(ValidateCppIdent("a", "")); + TF_EXPECT_OK(ValidateCppIdent("abc", "")); + TF_EXPECT_OK(ValidateCppIdent("_abc", "")); + TF_EXPECT_OK(ValidateCppIdent("_abc123", "")); + // Make sure we didn't skip a valid letter or digit + string ident; + for (char c = 'a'; c <= 'z'; c++) { + ident.append(1, c); + } + for (char c = 'A'; c <= 'Z'; c++) { + ident.append(1, c); + } + for (char c = '0'; c <= '9'; c++) { + ident.append(1, c); + } + ident += "_"; + TF_EXPECT_OK(ValidateCppIdent(ident, "")); + + ExpectErrorContains(ValidateCppIdent("", ""), "empty identifier"); + ExpectErrorContains(ValidateCppIdent(" ", ""), "illegal leading char"); + ExpectErrorContains(ValidateCppIdent("0", ""), "illegal leading char"); + ExpectErrorContains(ValidateCppIdent(".", ""), "illegal leading char"); + ExpectErrorContains(ValidateCppIdent(":", ""), "illegal leading char"); + ExpectErrorContains(ValidateCppIdent("a.", ""), "illegal char"); + ExpectErrorContains(ValidateCppIdent("a:", ""), "illegal char"); + ExpectErrorContains(ValidateCppIdent("a:", ""), "illegal char"); +} + +TEST(ValidateConfig, Good) { + Config config; + Feed* feed = config.add_feed(); + feed->mutable_id()->set_node_name("foo"); + feed->mutable_id()->set_output_index(123); + feed->set_name("foo_debug"); + feed = config.add_feed(); + feed->mutable_id()->set_node_name("bar"); + feed->mutable_id()->set_output_index(0); + Fetch* fetch = config.add_fetch(); + fetch->mutable_id()->set_node_name("baz"); + fetch->mutable_id()->set_output_index(456); + fetch->set_name("baz_debug"); + fetch = config.add_fetch(); + fetch->mutable_id()->set_node_name("banana"); + fetch->mutable_id()->set_output_index(0); + TF_EXPECT_OK(ValidateConfig(config)); +} + +TEST(ValidateConfig, BadEmpty) { + Config config; + ExpectErrorContains(ValidateConfig(config), + "feeds and fetches must be specified"); +} + +TEST(ValidateConfig, BadNoFeed) { + Config config; + Fetch* fetch = config.add_fetch(); + fetch->mutable_id()->set_node_name("foo"); + ExpectErrorContains(ValidateConfig(config), + "feeds and fetches must be specified"); +} + +TEST(ValidateConfig, BadNoFetch) { + Config config; + Feed* feed = config.add_feed(); + feed->mutable_id()->set_node_name("foo"); + ExpectErrorContains(ValidateConfig(config), + "feeds and fetches must be specified"); +} + +TEST(ValidateConfig, BadFeedNodeName) { + Config config; + config.add_feed(); + ExpectErrorContains(ValidateConfig(config), "node_name must be non-empty"); +} + +TEST(ValidateConfig, BadFeedOutputIndex) { + Config config; + Feed* feed = config.add_feed(); + feed->mutable_id()->set_node_name("foo"); + feed->mutable_id()->set_output_index(-1); + ExpectErrorContains(ValidateConfig(config), "output_index must be positive"); +} + +TEST(ValidateConfig, 
BadFetchNodeName) { + Config config; + Feed* feed = config.add_feed(); + feed->mutable_id()->set_node_name("foo"); + config.add_fetch(); + ExpectErrorContains(ValidateConfig(config), "node_name must be non-empty"); +} + +TEST(ValidateConfig, BadFetchOutputIndex) { + Config config; + Feed* feed = config.add_feed(); + feed->mutable_id()->set_node_name("foo"); + Fetch* fetch = config.add_fetch(); + fetch->mutable_id()->set_node_name("bar"); + fetch->mutable_id()->set_output_index(-1); + ExpectErrorContains(ValidateConfig(config), "output_index must be positive"); +} + +TEST(ValidateConfig, DuplicateFeedName) { + Config config; + Feed* feed = config.add_feed(); + feed->mutable_id()->set_node_name("foo"); + feed->set_name("dup"); + feed = config.add_feed(); + feed->mutable_id()->set_node_name("bar"); + feed->set_name("dup"); + ExpectErrorContains(ValidateConfig(config), "duplicate feed name"); +} + +TEST(ValidateConfig, DuplicateFetchName) { + Config config; + Feed* feed = config.add_feed(); + feed->mutable_id()->set_node_name("foo"); + Fetch* fetch = config.add_fetch(); + fetch->mutable_id()->set_node_name("bar"); + fetch->set_name("dup"); + fetch = config.add_fetch(); + fetch->mutable_id()->set_node_name("baz"); + fetch->set_name("dup"); + ExpectErrorContains(ValidateConfig(config), "duplicate fetch name"); +} + +TEST(ValidateConfig, ConflictingFeedName) { + Config config; + Feed* feed = config.add_feed(); + feed->mutable_id()->set_node_name("foo"); + feed->set_name("conflict"); + feed = config.add_feed(); + feed->mutable_id()->set_node_name("bar"); + feed->set_name("conflict_data"); + ExpectErrorContains(ValidateConfig(config), "conflicting feed name"); +} + +TEST(ValidateConfig, ConflictingFetchName) { + Config config; + Feed* feed = config.add_feed(); + feed->mutable_id()->set_node_name("foo"); + Fetch* fetch = config.add_fetch(); + fetch->mutable_id()->set_node_name("bar"); + fetch->set_name("conflict"); + fetch = config.add_fetch(); + fetch->mutable_id()->set_node_name("baz"); + fetch->set_name("conflict_data"); + ExpectErrorContains(ValidateConfig(config), "conflicting fetch name"); +} + +} // namespace +} // namespace tfcompile +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD new file mode 100644 index 0000000000..d2cfca207f --- /dev/null +++ b/tensorflow/compiler/jit/BUILD @@ -0,0 +1,282 @@ +licenses(["notice"]) # Apache 2.0 + +package_group( + name = "internal", + includes = [ + "//tensorflow/compiler/tf2xla:internal", + ], +) + +package_group( + name = "friends", + includes = [ + "//tensorflow/compiler/tf2xla:friends", + ], +) + +package( + default_visibility = [":internal"], +) + +load("//tensorflow:tensorflow.bzl", "tf_kernel_library") + +# Target that bundles up the XLA CPU and GPU JIT devices. 
+cc_library( + name = "jit", + visibility = [":friends"], + deps = [ + ":xla_cpu_device", + ":xla_cpu_jit", + ":xla_gpu_device", + ":xla_gpu_jit", + ], + alwayslink = 1, +) + +cc_library( + name = "xla_cpu_jit", + visibility = [":friends"], + deps = [ + ":jit_compilation_passes", + ":xla_local_launch_op", + "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/compiler/xla/service:cpu_plugin", + ], + alwayslink = 1, +) + +cc_library( + name = "xla_gpu_jit", + visibility = [":friends"], + deps = [ + ":jit_compilation_passes", + ":xla_local_launch_op", + "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/compiler/xla/service:gpu_plugin", + ], + alwayslink = 1, +) + +cc_library( + name = "xla_cpu_device", + srcs = ["xla_cpu_device.cc"], + visibility = [":friends"], + deps = [ + ":jit_compilation_passes", + ":xla_device", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla/service:cpu_plugin", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:lib", + ], + alwayslink = 1, +) + +cc_library( + name = "xla_gpu_device", + srcs = ["xla_gpu_device.cc"], + visibility = [":friends"], + deps = [ + ":jit_compilation_passes", + ":xla_device", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla/service:gpu_plugin", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:lib", + ], + alwayslink = 1, +) + +# Internal targets below this point. + +cc_library( + name = "common", + srcs = [ + "defs.cc", + ], + hdrs = [ + "defs.h", + ], + visibility = [":friends"], +) + +cc_library( + name = "xla_device", + srcs = [ + "xla_device.cc", + "xla_device_context.cc", + "xla_device_launch_op.cc", + "xla_device_ops.cc", + ], + hdrs = [ + "xla_device.h", + "xla_device_context.h", + "xla_device_launch_op.h", + "xla_device_ops.h", + ], + deps = [ + ":common", + ":jit_compilation_passes", + ":xla_compilation_cache", + ":xla_local_launch_op", + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:dump_graph", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core:tensorflow_opensource", + "//tensorflow/core/kernels:assign_op", + "//tensorflow/core/kernels:constant_op", + "//tensorflow/core/kernels:control_flow_ops", + "//tensorflow/core/kernels:identity_op", + "//tensorflow/core/kernels:no_op", + "//tensorflow/core/kernels:sendrecv_ops", + "//tensorflow/core/kernels:variable_ops", + ], + alwayslink = 1, +) + +cc_library( + name = "xla_compilation_cache", + srcs = ["xla_compilation_cache.cc"], + hdrs = ["xla_compilation_cache.h"], + deps = [ + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:dump_graph", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/service:cpu_plugin", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + ], +) + +cc_library( + 
name = "jit_compilation_passes", + srcs = ["jit_compilation_pass_registration.cc"], + deps = [ + ":compilation_passes", + "//tensorflow/core:core_cpu_internal", + ], + alwayslink = 1, +) + +cc_library( + name = "compilation_passes", + srcs = [ + "build_xla_launch_ops_pass.cc", + "encapsulate_subgraphs_pass.cc", + "graph_to_functiondef.cc", + "mark_for_compilation_pass.cc", + ], + hdrs = [ + "build_xla_launch_ops_pass.h", + "encapsulate_subgraphs_pass.h", + "graph_to_functiondef.h", + "mark_for_compilation_pass.h", + ], + deps = [ + ":common", + ":parallel_check_op", + ":xla_local_launch_op", + "//tensorflow/compiler/jit/graphcycles", + "//tensorflow/compiler/jit/legacy_flags:encapsulate_subgraphs_pass_flags", + "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags", + "//tensorflow/compiler/tf2xla:const_analysis", + "//tensorflow/compiler/tf2xla:dump_graph", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + ], +) + +cc_test( + name = "compilation_passes_test", + size = "small", + srcs = [ + "encapsulate_subgraphs_pass_test.cc", + "graph_to_functiondef_test.cc", + "mark_for_compilation_pass_test.cc", + ], + deps = [ + ":compilation_passes", + ":xla_local_launch_op", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:function_ops", + "//tensorflow/cc:ops", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + +cc_library( + name = "xla_local_launch_op", + srcs = ["xla_local_launch_op.cc"], + hdrs = ["xla_local_launch_op.h"], + deps = [ + ":common", + ":xla_compilation_cache", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_local_runtime_context", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core:tensorflow_opensource", + ], + alwayslink = 1, +) + +tf_kernel_library( + name = "parallel_check_op", + srcs = ["parallel_check_op.cc"], + visibility = [":friends"], + deps = [ + "//tensorflow/compiler/jit/legacy_flags:parallel_check_op_flags", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], + alwayslink = 1, +) + +# ----------------------------------------------------------------------------- + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc b/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc new file mode 100644 index 0000000000..8fde197400 --- /dev/null +++ b/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc @@ -0,0 +1,215 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/build_xla_launch_ops_pass.h" +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" +#include "tensorflow/compiler/jit/mark_for_compilation_pass.h" +#include "tensorflow/compiler/jit/xla_local_launch_op.h" +#include "tensorflow/compiler/tf2xla/const_analysis.h" +#include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/graph_def_util.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/public/version.h" + +namespace tensorflow { + +static Status BuildLaunchNode( + const string& nodename, const string& function_name, + const AttrValueMap& function_attr, const string& device_name, + const DataTypeVector& constant_dtypes, const DataTypeVector& arg_dtypes, + const DataTypeVector& result_dtypes, Graph* graph, Node** node) { + NodeDef def; + def.set_name(graph->NewName(nodename)); + def.set_op("_XlaLaunch"); + def.set_device(device_name); + AddNodeAttr("Tconstants", constant_dtypes, &def); + AddNodeAttr("Targs", arg_dtypes, &def); + AddNodeAttr("Tresults", result_dtypes, &def); + NameAttrList function; + function.set_name(function_name); + *function.mutable_attr() = function_attr; + AddNodeAttr("function", function, &def); + + Status status; + *node = graph->AddNode(def, &status); + return status; +} + +static Status ReplaceNodeWithXlaLaunch(Graph* graph, Node* node) { + VLOG(2) << "Replacing " << node->name() << " with XlaLaunch"; + + int num_constant_args; + TF_RETURN_IF_ERROR( + GetNodeAttr(node->def(), kXlaNumConstantArgsAttr, &num_constant_args)); + + if (num_constant_args < 0 || num_constant_args > node->input_types().size()) { + return errors::InvalidArgument( + "Invalid number of constant arguments to XLA kernel"); + } + DataTypeVector const_dtypes(node->input_types().begin(), + node->input_types().begin() + num_constant_args); + DataTypeVector arg_dtypes(node->input_types().begin() + num_constant_args, + node->input_types().end()); + + // Build a _XlaLaunch operator to execute the function body. + Node* launch_node; + TF_RETURN_IF_ERROR( + BuildLaunchNode(graph->NewName(node->name()), node->type_string(), + node->def().attr(), node->def().device(), const_dtypes, + arg_dtypes, node->output_types(), graph, &launch_node)); + launch_node->set_assigned_device_name(node->assigned_device_name()); + + // Copy incoming edges to the launch node. 
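  // [Editor's note: illustrative sketch, not part of the original patch.]
  // For an assumed cluster function "cluster_0" with one compile-time
  // constant input and two ordinary inputs, the NodeDef assembled by
  // BuildLaunchNode() above would look roughly like this textproto (the
  // name, dtypes and device string are illustrative assumptions):
  //
  //   name: "cluster_0/_0"
  //   op: "_XlaLaunch"
  //   device: "/job:localhost/replica:0/task:0/cpu:0"
  //   attr { key: "Tconstants" value { list { type: DT_INT32 } } }
  //   attr { key: "Targs"      value { list { type: DT_FLOAT type: DT_FLOAT } } }
  //   attr { key: "Tresults"   value { list { type: DT_FLOAT } } }
  //   attr { key: "function"   value { func { name: "cluster_0" } } }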
+ for (const Edge* edge : node->in_edges()) { + if (edge->IsControlEdge()) { + graph->AddControlEdge(edge->src(), launch_node); + } else { + graph->AddEdge(edge->src(), edge->src_output(), launch_node, + edge->dst_input()); + } + } + + // Copy outgoing edges to the launch node. + std::vector out_edges(node->out_edges().begin(), + node->out_edges().end()); + for (const Edge* edge : out_edges) { + Node* dst = edge->dst(); + int src_output = edge->src_output(); + int dst_input = edge->dst_input(); + graph->RemoveEdge(edge); + + if (edge->IsControlEdge()) { + graph->AddControlEdge(launch_node, dst); + } else { + graph->AddEdge(launch_node, src_output, dst, dst_input); + } + } + graph->RemoveNode(node); + + return Status::OK(); +} + +Status BuildXlaLaunchOpsPass::Run(const GraphOptimizationPassOptions& options) { + Graph* graph = options.graph->get(); + + for (Node* n : graph->nodes()) { + // In all cases, only try to compile computational nodes. + if (!n->IsOp() || n->IsSend() || n->IsRecv() || n->IsControlFlow()) { + continue; + } + + // Only compile nodes that are marked for compilation by the + // compilation-marking pass (via 'attr_name'). + if (IsXlaCompiledKernel(*n)) { + TF_RETURN_IF_ERROR(ReplaceNodeWithXlaLaunch(graph, n)); + } + } + return Status::OK(); +} + +namespace { + +// Givens a NodeDef 'ndef' and the function library runtime 'flr', if +// 'ndef' is a call to a compilable function defined in 'flr', returns OK +// and fills in 'kernel' with a XlaLaunchOp kernel which computes the +// node. Otherwise, returns a non-OK. +// +// This routine is here so that FunctionLibraryRuntime can jit a +// specific function call as requested. +Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& ndef, + std::unique_ptr* kernel) { + bool xla_compile = false; + if (!flr->GetFunctionLibraryDefinition() + ->GetAttr(ndef, kXlaCompileAttr, &xla_compile) + .ok() || + !xla_compile) { + // Not marked as _XlaCompile=true. + return errors::InvalidArgument("No ", kXlaCompileAttr, " for ", ndef.op()); + } + // Make sure that kernels have been registered on the JIT device. + XlaOpRegistry::RegisterJitKernels(); + if (!IsCompilable(flr, ndef)) { + // ndef is calling a function that XLA can't compile. + return errors::InvalidArgument("Not compilable: ", ndef.ShortDebugString()); + } + FunctionLibraryRuntime::Handle handle; + // If ndef is not instantiable, e.g., the function does not exist, + // simply bail out. + TF_RETURN_IF_ERROR(flr->Instantiate(ndef.op(), ndef.attr(), &handle)); + const FunctionBody* fbody = flr->GetFunctionBody(handle); + CHECK(fbody); // Can't be nullptr since we just instantiated it. + std::vector const_args(fbody->arg_types.size()); + // If we can't analyze the const args. Bail out. + TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*(fbody->graph), &const_args)); + + for (int i = 0; i < const_args.size(); ++i) { + if (const_args[i]) { + // There is a const arg. Bail out. 
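      // [Editor's note: illustrative, not part of the original patch.]
      // The launch NodeDef built below hard-codes Tconstants to an empty
      // list, so a function with any compile-time constant argument is
      // rejected on this direct-call path. A typical case (assumed graph) is
      // an argument feeding the shape input of a Reshape node:
      //
      //   x: float --> Reshape <-- shape: int32
      //
      // where BackwardsConstAnalysis marks 'shape' as constant. Such
      // functions are still handled by ReplaceNodeWithXlaLaunch() above,
      // which splits constant arguments into Tconstants.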
+ return errors::InvalidArgument("Const arg: ", i, " in ", + DebugString(fbody->fdef)); + } + } + + NodeDef launch_def; + launch_def.set_name(ndef.name()); + launch_def.set_op("_XlaLaunch"); + launch_def.set_device(flr->device()->name()); + AddNodeAttr("Tconstants", DataTypeVector{}, &launch_def); + AddNodeAttr("Targs", fbody->arg_types, &launch_def); + AddNodeAttr("Tresults", fbody->ret_types, &launch_def); + NameAttrList func; + func.set_name(ndef.op()); + *(func.mutable_attr()) = ndef.attr(); + AddNodeAttr("function", func, &launch_def); + + // TODO(b/32387911): Handles the host memory types across function + // calls properly. For now, we assume all inputs and outputs are on + // the device memory. + MemoryTypeVector input_memory_types(fbody->arg_types.size(), DEVICE_MEMORY); + MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY); + + Device* dev = flr->device(); + Status s; + OpKernelConstruction construction( + DeviceType(dev->device_type()), dev, + dev->GetAllocator(AllocatorAttributes()), &launch_def, + &fbody->fdef.signature(), flr, fbody->arg_types, input_memory_types, + fbody->ret_types, output_memory_types, flr->graph_def_version(), &s); + kernel->reset(new XlaLocalLaunchOp(&construction)); + return s; +} + +bool RegisterLaunchOpCreator() { + RegisterDefaultCustomKernelCreator(CreateXlaLaunchOp); + return true; +} + +static bool register_me = RegisterLaunchOpCreator(); + +} // end namespace + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/build_xla_launch_ops_pass.h b/tensorflow/compiler/jit/build_xla_launch_ops_pass.h new file mode 100644 index 0000000000..1dfea93f02 --- /dev/null +++ b/tensorflow/compiler/jit/build_xla_launch_ops_pass.h @@ -0,0 +1,31 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_BUILD_XLA_LAUNCH_OPS_PASS_H_ +#define TENSORFLOW_COMPILER_JIT_BUILD_XLA_LAUNCH_OPS_PASS_H_ + +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +class BuildXlaLaunchOpsPass : public GraphOptimizationPass { + public: + Status Run(const GraphOptimizationPassOptions& options) override; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_BUILD_XLA_LAUNCH_OPS_PASS_H_ diff --git a/tensorflow/compiler/jit/defs.cc b/tensorflow/compiler/jit/defs.cc new file mode 100644 index 0000000000..b20ad53ef6 --- /dev/null +++ b/tensorflow/compiler/jit/defs.cc @@ -0,0 +1,22 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/defs.h" + +namespace tensorflow { + +const char* const kXlaCompileAttr = "_XlaCompile"; + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/defs.h b/tensorflow/compiler/jit/defs.h new file mode 100644 index 0000000000..ddc830cb77 --- /dev/null +++ b/tensorflow/compiler/jit/defs.h @@ -0,0 +1,29 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Provides definitions needed for use of the TensorFlow XLA +// device. + +#ifndef TENSORFLOW_COMPILER_JIT_DEFS_H_ +#define TENSORFLOW_COMPILER_JIT_DEFS_H_ + +namespace tensorflow { + +// Name of attribute used to tag operators for compilation with XLA +extern const char* const kXlaCompileAttr; // "_XlaCompile" + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_DEFS_H_ diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc new file mode 100644 index 0000000000..72c440abe8 --- /dev/null +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -0,0 +1,660 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" + +#include +#include + +#include "tensorflow/compiler/jit/graph_to_functiondef.h" +#include "tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h" +#include "tensorflow/compiler/jit/mark_for_compilation_pass.h" +#include "tensorflow/compiler/tf2xla/const_analysis.h" +#include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/graph_def_util.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/tensor_id.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/public/version.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { + +const char* const kXlaCompiledKernelAttr = "_XlaCompiledKernel"; +const char* const kXlaNumConstantArgsAttr = "_XlaNumConstantArgs"; + +namespace { + +// A node/slot pair. +// TODO(phawkins): is there a common definition of this? +struct NodeSlot { + NodeSlot() : node(nullptr), slot(-1) {} + NodeSlot(const Node* node, int slot) : node(node), slot(slot) {} + + const Node* node; + int slot; + + bool operator==(const NodeSlot& other) const { + return node == other.node && slot == other.slot; + } + + struct Hasher { + uint64 operator()(NodeSlot const& s) const { + return Hash64Combine(std::hash()(s.node), + std::hash()(s.slot)); + } + }; + + struct PairHasher { + uint64 operator()(std::pair const& s) const { + return Hash64Combine(Hasher()(s.first), Hasher()(s.second)); + } + }; +}; + +class Encapsulator { + public: + Encapsulator(string group_attribute, Graph const* graph_in) + : group_attribute_(std::move(group_attribute)), graph_in_(graph_in) {} + + // Find subgraphs marked with 'group_attribute', and build a new + // subgraph, one for each value of 'group_attribute'. + Status SplitIntoSubgraphs(); + + // Build a FunctionDef for each subgraph, and add it 'library'. The values of + // the 'group_attribute' annotations become the function names. + // If 'rewrite_subgraph_fn' is set, it is applied to each subgraph before + // function conversion. + Status BuildFunctionDefs(const RewriteSubgraphFn& rewrite_subgraph_fn, + FunctionLibraryDefinition* library); + + // Write a copy of the input graph to 'graph_out', where the subgraphs are + // replaced with calls to the new functions. + Status BuildOutputGraph(bool parallel_checking, Graph* graph_out); + + private: + // Returns the key attribute associated with a node. Returns the empty string + // if no key attribute is found. + string GetFunctionNameAttr(const Node* node) const; + + // A subgraph of the input, all marked with a common 'group_attribute' + // value. + struct Subgraph { + // The subgraph extracted from the input graph, suitable for being turned + // into a FunctionDef. 
Inputs are fed by _Arg nodes, and outputs are + // returned by _Retval nodes. + std::unique_ptr graph; + + // Which device are these nodes on? Used both to check that all nodes + // are assigned to the same device, and to assign a device to the call node. + string device; + + // NodeDef for the function call node. + NodeDef call_node_def; + + // Function call node(s) in the output graph. Not owned. + // If parallel_checking is enabled, 'call_node_inputs' is the function call + // node to which inputs should be fed, and 'call_node_outputs' is the + // parallel check op from which outputs should be read. If parallel checking + // is disabled, both point to the function call node. + Node* call_node_inputs; + Node* call_node_outputs; + + // Maps from source (producer node/slot) and destination + // (consumer node/slot) tensors in the input graph to _Arg numbers in + // the subgraph. The source map is one-to-one, whereas the dest map may be + // many-to-one. + std::unordered_map args_by_src; + std::unordered_map args_by_dst; + + // The _Arg nodes in the subgraph, in order by argument number. + std::vector args; + + // Map from source tensor in the input graph to result #. + std::unordered_map results; + }; + + // Builds a ParallelCheck op that compares the output of the original subgraph + // with the encapsulated subgraph. + Status BuildParallelCheckOp( + const std::unordered_map& node_images, + const Subgraph& subgraph, Graph* graph_out, Node** parallel_check_op); + + const string group_attribute_; + const Graph* graph_in_; + + std::unordered_map subgraphs_; + + TF_DISALLOW_COPY_AND_ASSIGN(Encapsulator); +}; + +// TODO(phawkins) add a canonical copy of these operator names and refactor +// everything to use it. +static const char* const kArgOp = "_Arg"; +static const char* const kRetValOp = "_Retval"; + +// Returns the function name attached to 'node', or the empty string if there is +// none. +string Encapsulator::GetFunctionNameAttr(Node const* node) const { + string attr; + if (!GetNodeAttr(node->def(), group_attribute_, &attr).ok()) { + attr.clear(); + } + return attr; +} + +Status Encapsulator::SplitIntoSubgraphs() { + Status s; + + // Map from input graph nodes to subgraph nodes. + std::unordered_map node_images; + + // Copy all marked nodes to a subgraph. Do nothing for unmarked nodes. + for (Node* node : graph_in_->nodes()) { + if (node->IsSource() || node->IsSink()) continue; + string func_id = GetFunctionNameAttr(node); + if (func_id.empty()) continue; + + Subgraph& subgraph = subgraphs_[func_id]; + if (!subgraph.graph) { + subgraph.graph.reset(new Graph(graph_in_->op_registry())); + subgraph.graph->set_versions(graph_in_->versions()); + } + + Node* image = subgraph.graph->CopyNode(node); + image->ClearAttr(group_attribute_); + node_images[node] = image; + + // Check the device matches any existing device. + string device = node->assigned_device_name().empty() + ? node->def().device() + : node->assigned_device_name(); + + if (subgraph.device.empty()) { + subgraph.device = device; + } else if (subgraph.device != device) { + s.Update(errors::InvalidArgument( + "Mismatched devices for nodes to be grouped by Encapsulator")); + } + } + + // Copy edges local to a subgraph. Add _Arg and _Retval nodes to subgraphs for + // data edges that cross subgraph boundaries. 
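  // [Editor's note: illustrative sketch, not part of the original patch.]
  // For an assumed data edge A:0 -> B:0 where A is in cluster F1 and B is in
  // cluster F2, the loop below records roughly:
  //
  //   subgraph F1:  A -> _Retval(index = r)     results[{A,0}]     = r
  //   subgraph F2:  _Arg(index = a) -> B        args_by_src[{A,0}] = a
  //                                             args_by_dst[{B,0}] = a
  //
  // The cross-cluster edge itself is re-created between the two call nodes
  // later, in BuildOutputGraph().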
+ for (const Edge* edge : graph_in_->edges()) { + string src_func_id = GetFunctionNameAttr(edge->src()); + string dst_func_id = GetFunctionNameAttr(edge->dst()); + Node* src_image = gtl::FindWithDefault(node_images, edge->src(), nullptr); + Node* dst_image = gtl::FindWithDefault(node_images, edge->dst(), nullptr); + + // Copy edges that are local to a subgraph. + if (!src_func_id.empty() && src_func_id == dst_func_id) { + Graph* g = subgraphs_[src_func_id].graph.get(); + if (edge->IsControlEdge()) { + g->AddControlEdge(src_image, dst_image); + } else { + g->AddEdge(src_image, edge->src_output(), dst_image, edge->dst_input()); + } + continue; + } + + // Ignore cross-boundary control edges for right now. We will lift them + // onto the enclosing call operators in BuildOutputGraph(). + if (edge->IsControlEdge()) continue; + + // Add 'src' as an output of its subgraph, if applicable. + if (!src_func_id.empty()) { + Subgraph& src_subgraph = subgraphs_[src_func_id]; + int ret_index = src_subgraph.results.size(); + if (src_subgraph.results + .emplace(NodeSlot(edge->src(), edge->src_output()), ret_index) + .second) { + // Create a new _Retval node + DataType dtype = edge->src()->output_type(edge->src_output()); + + NodeDef ret_def; + ret_def.set_op(kRetValOp); + ret_def.set_name(src_subgraph.graph->NewName("output")); + AddNodeAttr("T", dtype, &ret_def); + AddNodeAttr("index", ret_index, &ret_def); + Node* ret = src_subgraph.graph->AddNode(ret_def, &s); + if (!s.ok()) return s; + + // Add an edge from 'src' to _Retval. + src_subgraph.graph->AddEdge(src_image, edge->src_output(), ret, 0); + } + } + + // Add 'dst' as an input of its subgraph, if applicable. + if (!dst_func_id.empty()) { + Subgraph& dst_subgraph = subgraphs_[dst_func_id]; + + // Create an _Arg node for this tensor, if none exists yet. + std::unordered_map::iterator iter; + bool inserted; + std::tie(iter, inserted) = dst_subgraph.args_by_src.emplace( + NodeSlot(edge->src(), edge->src_output()), dst_subgraph.args.size()); + int arg_index = iter->second; + if (inserted) { + // This is the first time we have seen this tensor. Create an _Arg node. + DataType dtype = edge->dst()->input_type(edge->dst_input()); + + NodeDef arg_def; + NodeDefBuilder builder(dst_subgraph.graph->NewName("input"), kArgOp); + builder.Attr("T", dtype); + builder.Attr("index", arg_index); + s = builder.Finalize(&arg_def); + if (!s.ok()) return s; + + Node* arg = dst_subgraph.graph->AddNode(arg_def, &s); + if (!s.ok()) return s; + + dst_subgraph.args.push_back(arg); + } + // Add an edge from the _Arg node to 'dst' in the subgraph. + dst_subgraph.args_by_dst[NodeSlot(edge->dst(), edge->dst_input())] = + arg_index; + dst_subgraph.graph->AddEdge(dst_subgraph.args[arg_index], 0, dst_image, + edge->dst_input()); + } + } + + for (auto& entry : subgraphs_) { + FixupSourceAndSinkEdges(entry.second.graph.get()); + } + + return s; +} + +Status Encapsulator::BuildFunctionDefs( + const RewriteSubgraphFn& rewrite_subgraph_fn, + FunctionLibraryDefinition* library) { + // For each subgraph, build a FunctionDef. + for (auto& subgraph_entry : subgraphs_) { + const string& name = subgraph_entry.first; + Subgraph& subgraph = subgraph_entry.second; + + subgraph.call_node_def.set_op(name); + subgraph.call_node_def.set_name(name); + subgraph.call_node_def.set_device(subgraph.device); + + if (rewrite_subgraph_fn) { + // Initialize the input and output permutations to the identity. 
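      // [Editor's note: worked example, not part of the original patch.]
      // The permutations start out as the identity; rewrite_subgraph_fn may
      // then reorder arguments and results. For an assumed subgraph with
      // three inputs of which only input 2 is a compile-time constant, the
      // constants-first rewrite installed by EncapsulateSubgraphsPass::Run()
      // would return
      //
      //   input_permutation = {1, 2, 0}   // old argument number -> new
      //
      // and the remapping below updates args_by_src/args_by_dst so that
      // BuildOutputGraph() connects edges to the permuted call-node ports.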
+ std::vector input_permutation(subgraph.args_by_src.size()); + std::iota(input_permutation.begin(), input_permutation.end(), 0); + std::vector output_permutation(subgraph.results.size()); + std::iota(output_permutation.begin(), output_permutation.end(), 0); + + TF_RETURN_IF_ERROR( + rewrite_subgraph_fn(&subgraph.graph, &input_permutation, + &output_permutation, &subgraph.call_node_def)); + + // Apply the input/output permutations to the 'args_by_...' and 'results' + // mappings in 'subgraph', so when we build edges in BuildOutputGraph() we + // connect them to the right input/output positions. + if (input_permutation.size() != subgraph.args_by_src.size()) { + return errors::InvalidArgument("Input permutation has incorrect size."); + } + if (output_permutation.size() != subgraph.results.size()) { + return errors::InvalidArgument( + "Output permutation has incorrect size."); + } + for (auto& arg : subgraph.args_by_src) { + arg.second = input_permutation[arg.second]; + } + for (auto& arg : subgraph.args_by_dst) { + arg.second = input_permutation[arg.second]; + } + for (auto& result : subgraph.results) { + result.second = output_permutation[result.second]; + } + } + + FunctionDef fdef; + TF_RETURN_IF_ERROR(GraphToFunctionDef(*subgraph.graph, name, &fdef)); + + if (VLOG_IS_ON(1)) { + VLOG(2) << "Build function def " << name; + dump_graph::DumpGraphToFile( + strings::StrCat("encapsulate_fdef_graph_", name), *subgraph.graph, + library); + dump_graph::DumpFunctionDefToFile( + strings::StrCat("encapsulate_fdef_", name), fdef); + } + + TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef)); + } + return Status::OK(); +} + +Status Encapsulator::BuildParallelCheckOp( + const std::unordered_map& node_images, + const Encapsulator::Subgraph& subgraph, Graph* graph_out, + Node** parallel_check_op) { + // Build an index mapping output positions to node/slot pairs in the + // original graph. + std::vector results_by_num(subgraph.results.size()); + for (const auto& entry : subgraph.results) { + results_by_num[entry.second] = entry.first; + } + + // Build a parallel check NodeDef. + int num_results = results_by_num.size(); + std::vector result_dtypes(num_results); + std::vector expected_outputs(num_results); + std::vector actual_outputs(num_results); + for (int i = 0; i < num_results; ++i) { + const NodeSlot& node_slot = results_by_num[i]; + result_dtypes[i] = node_slot.node->output_type(node_slot.slot); + expected_outputs[i] = + NodeDefBuilder::NodeOut(node_images.at(node_slot.node)->name(), + node_slot.slot, result_dtypes[i]); + actual_outputs[i] = NodeDefBuilder::NodeOut(subgraph.call_node_def.name(), + i, result_dtypes[i]); + } + // Assign the parallel check op to a CPU on the same task as the cluster it is + // checking. 
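  // [Editor's note: worked example, not part of the original patch.]
  // SplitDeviceName() keeps the job/replica/task prefix of the assigned
  // device and discards the device-type suffix. For an assumed cluster
  // assigned to
  //
  //   "/job:worker/replica:0/task:3/gpu:1"
  //
  // 'device' becomes "/job:worker/replica:0/task:3", and after the
  // StrAppend below the check op is pinned to
  // "/job:worker/replica:0/task:3/cpu:0".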
+ string device, dummy; + if (!DeviceNameUtils::SplitDeviceName( + subgraph.call_node_inputs->assigned_device_name(), &device, &dummy)) { + return errors::InvalidArgument("Could not parse device name"); + } + strings::StrAppend(&device, "/cpu:0"); + + NodeDef check_def; + TF_RETURN_IF_ERROR( + NodeDefBuilder(graph_out->NewName(strings::StrCat( + subgraph.call_node_def.name(), "_parallel_check")), + "ParallelCheck") + .Device(device) + .Attr("T", result_dtypes) + .Input(expected_outputs) + .Input(actual_outputs) + .Finalize(&check_def)); + + Status s; + Node* check_op = graph_out->AddNode(check_def, &s); + if (!s.ok()) return s; + check_op->set_assigned_device_name(device); + + // TODO(phawkins): it seems redundant to call AddEdge as well as + // pass Inputs to the NodeDefBuilder, but I have been unable to find a + // way to avoid it. + for (int i = 0; i < num_results; ++i) { + const NodeSlot& node_slot = results_by_num[i]; + graph_out->AddEdge(node_images.at(node_slot.node), node_slot.slot, check_op, + i); + graph_out->AddEdge(subgraph.call_node_inputs, i, check_op, num_results + i); + } + + *parallel_check_op = check_op; + return Status::OK(); +} + +Status Encapsulator::BuildOutputGraph(bool parallel_checking, + Graph* graph_out) { + Status s; + + // Map from nodes in the input graph to nodes in the output graph. + std::unordered_map node_images; + + // Copy all unmarked nodes to the output graph. + for (Node* node : graph_in_->nodes()) { + if (node->IsSource() || node->IsSink()) continue; + string func_id = GetFunctionNameAttr(node); + + // Don't copy nodes that going to be encapsulated, unless parallel checking + // is enabled. + if (!func_id.empty() && !parallel_checking) continue; + + Node* image = graph_out->CopyNode(node); + node_images[node] = image; + } + node_images[graph_in_->source_node()] = graph_out->source_node(); + node_images[graph_in_->sink_node()] = graph_out->sink_node(); + + // Add function call nodes for each subgraph. + for (auto& subgraph_entry : subgraphs_) { + Subgraph& subgraph = subgraph_entry.second; + + subgraph.call_node_inputs = graph_out->AddNode(subgraph.call_node_def, &s); + if (!s.ok()) return s; + + // Copy the assigned device and the key_annotation over. + subgraph.call_node_inputs->set_assigned_device_name(subgraph.device); + subgraph.call_node_outputs = subgraph.call_node_inputs; + + if (parallel_checking) { + TF_RETURN_IF_ERROR(BuildParallelCheckOp(node_images, subgraph, graph_out, + &subgraph.call_node_outputs)); + } + } + + // Set of edges already added to the output graph, represented as (src, dst) + // pairs. We use the set to deduplicate edges; multiple edges in the input + // graph may map to one edge in the output graph. + std::unordered_set, NodeSlot::PairHasher> + edges_added; + + // Add edges to the graph_out graph. + for (const Edge* edge : graph_in_->edges()) { + string src_func_id = GetFunctionNameAttr(edge->src()); + string dst_func_id = GetFunctionNameAttr(edge->dst()); + + // Ignore edges that are strictly contained within one subgraph, unless + // we are constructing parallel check graphs. + if (!src_func_id.empty() && src_func_id == dst_func_id) { + if (parallel_checking) { + Node* src_image = node_images.at(edge->src()); + Node* dst_image = node_images.at(edge->dst()); + if (edge->IsControlEdge()) { + graph_out->AddControlEdge(src_image, dst_image); + } else { + graph_out->AddEdge(src_image, edge->src_output(), dst_image, + edge->dst_input()); + } + } + continue; + } + + // We have an edge that crosses a cluster boundary. 
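    // [Editor's note: illustrative sketch, not part of the original patch.]
    // Continuing the assumed edge A:0 -> B:0 from SplitIntoSubgraphs(), with
    // A in cluster F1 and B in cluster F2, the code below rewires it between
    // the two call nodes using the port numbers recorded earlier:
    //
    //   src port = F1.results.at({A,0})         (call node "F1" output)
    //   dst port = F2.args_by_dst.at({B,0})     (call node "F2" input)
    //
    // so the output graph contains F1:src_port -> F2:dst_port instead of the
    // original A -> B edge.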
+ Node* src_image = src_func_id.empty() + ? node_images.at(edge->src()) + : subgraphs_.at(src_func_id).call_node_outputs; + Node* dst_image = dst_func_id.empty() + ? node_images.at(edge->dst()) + : subgraphs_.at(dst_func_id).call_node_inputs; + + // Copy control edges. Lift control edges onto the enclosing call operator. + if (edge->IsControlEdge()) { + // Add the control edge, if we have not already added it. + if (edges_added.emplace(NodeSlot(src_image, -1), NodeSlot(dst_image, -1)) + .second) { + graph_out->AddControlEdge(src_image, dst_image); + } + + // If parallel checking is enabled, also add a control edge to the + // corresponding parallel check op. + if (parallel_checking) { + graph_out->AddControlEdge(src_image, node_images.at(edge->dst())); + } + continue; + } + + int src_output = edge->src_output(); + if (!src_func_id.empty()) { + // 'src' is in a subgraph. Use the corresponding call output instead. + const Subgraph& src_subgraph = subgraphs_.at(src_func_id); + src_output = + src_subgraph.results.at(NodeSlot(edge->src(), edge->src_output())); + } + + int dst_input = edge->dst_input(); + + if (!dst_func_id.empty()) { + // 'dst' is in a subgraph. Use the corresponding call input instead. + const Subgraph& dst_subgraph = subgraphs_.at(dst_func_id); + dst_input = + dst_subgraph.args_by_dst.at(NodeSlot(edge->dst(), edge->dst_input())); + + // If we are parallel checking, also feed the tensor as an input to the + // corresponding parallel check subgraph. + if (parallel_checking) { + graph_out->AddEdge(src_image, src_output, node_images.at(edge->dst()), + edge->dst_input()); + } + } + // Add the edge, if we have not already added it. + if (edges_added + .emplace(NodeSlot(src_image, src_output), + NodeSlot(dst_image, dst_input)) + .second) { + graph_out->AddEdge(src_image, src_output, dst_image, dst_input); + } + } + + return s; +} + +} // anonymous namespace + +Status EncapsulateSubgraphsInFunctions( + string group_attribute, const Graph& graph_in, + const RewriteSubgraphFn& rewrite_subgraph_fn, bool parallel_checking, + std::unique_ptr* graph_out, FunctionLibraryDefinition* library) { + Status s; + + Encapsulator encapsulator(std::move(group_attribute), &graph_in); + s = encapsulator.SplitIntoSubgraphs(); + if (!s.ok()) return s; + + s = encapsulator.BuildFunctionDefs(rewrite_subgraph_fn, library); + if (!s.ok()) return s; + + std::unique_ptr out(new Graph(library)); + out->set_versions(graph_in.versions()); + s = encapsulator.BuildOutputGraph(parallel_checking, out.get()); + if (!s.ok()) return s; + + *graph_out = std::move(out); + return s; +} + +// Renumber the indices of _Arg nodes in a graph, according to +// 'permutation' that maps old indices to new indices. 
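// [Editor's note: worked example, not part of the original patch.]
// With an assumed constants-first permutation {1, 2, 0} (old index -> new
// index), the 'index' attrs are rewritten as
//
//   before: _Arg(index=0)  _Arg(index=1)  _Arg(index=2)
//   after:  _Arg(index=1)  _Arg(index=2)  _Arg(index=0)
//
// while the graph edges themselves are left untouched.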
+static Status RenumberArguments(Graph* graph, + const std::vector& permutation) { + for (Node* n : graph->nodes()) { + if (n->type_string() == kArgOp) { + int index; + TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index)); + if (index < 0 || index >= permutation.size()) { + return errors::InvalidArgument("Invalid argument number"); + } + n->AddAttr("index", permutation[index]); + } + } + return Status::OK(); +} + +Status EncapsulateSubgraphsPass::Run( + const GraphOptimizationPassOptions& options) { + VLOG(1) << "EncapsulateSubgraphsPass::Run"; + legacy_flags::EncapsulateSubgraphsPassFlags* flags = + legacy_flags::GetEncapsulateSubgraphsPassFlags(); + if (VLOG_IS_ON(1)) { + dump_graph::DumpGraphToFile("before_encapsulate_subgraphs", **options.graph, + options.flib_def); + } + + std::unique_ptr graph_out; + FunctionLibraryDefinition* const library = options.flib_def; + + OptimizerOptions opts; + std::unique_ptr flr( + NewFunctionLibraryRuntime(nullptr, options.session_options->env, nullptr, + TF_GRAPH_DEF_VERSION, library, opts)); + + auto rewrite_subgraph = [&flr]( + std::unique_ptr* subgraph, std::vector* input_permutation, + std::vector* output_permutation, NodeDef* node) { + // Optimize the subgraph. + Graph* g = subgraph->release(); + OptimizeGraph(flr.get(), &g); + subgraph->reset(g); + + std::vector const_args(input_permutation->size()); + TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*g, &const_args)); + + // Compute a permutation of the arguments such that the constant arguments + // are first. + const int num_consts = + std::count(const_args.begin(), const_args.end(), true); + int const_pos = 0; + int arg_pos = num_consts; + for (int i = 0; i < const_args.size(); ++i) { + if (const_args[i]) { + (*input_permutation)[i] = const_pos; + ++const_pos; + } else { + (*input_permutation)[i] = arg_pos; + ++arg_pos; + } + } + + // Renumber argument nodes in the graph. + TF_RETURN_IF_ERROR(RenumberArguments(subgraph->get(), *input_permutation)); + + // TODO(phawkins): add a forward is-constant analysis, similarly split + // outputs into host-memory constants and device-memory non-constants. + + AddNodeAttr(kXlaCompiledKernelAttr, true, node); + AddNodeAttr(kXlaNumConstantArgsAttr, num_consts, node); + return Status::OK(); + }; + + TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions( + kXlaClusterAttr, **options.graph, rewrite_subgraph, + flags->tf_xla_parallel_checking, &graph_out, library)); + + if (VLOG_IS_ON(1)) { + dump_graph::DumpGraphToFile("after_encapsulate_subgraphs", *graph_out, + options.flib_def); + } + + *options.graph = std::move(graph_out); + return Status::OK(); +} + +bool IsXlaCompiledKernel(const Node& node) { + bool is_compiled = false; + bool has_compilation_attr = + GetNodeAttr(node.def(), kXlaCompiledKernelAttr, &is_compiled).ok() && + is_compiled; + return has_compilation_attr ? is_compiled : false; +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h new file mode 100644 index 0000000000..ffd39f0b77 --- /dev/null +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h @@ -0,0 +1,86 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// An optimization pass that groups nodes marked with a common +// kXlaClusterAttr into functions, and replaces the original nodes by +// calls. The calls are annotated with kXlaCompiledKernelAttr. + +#ifndef TENSORFLOW_COMPILER_JIT_ENCAPSULATE_SUBGRAPHS_PASS_H_ +#define TENSORFLOW_COMPILER_JIT_ENCAPSULATE_SUBGRAPHS_PASS_H_ + +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// A rewriting function to apply to each subgraph during encapsulation. +// 'graph' is the subgraph. The rewriting may renumber the inputs and outputs; +// 'input_permutation' is a mapping from old argument numbers to new argument +// numbers, whereas 'output_permutation' is the same for outputs. Both +// 'input_permutation' and 'output_permutation' are initialized to the identity +// permutation. 'nodedef' is the NodeDef for the call to the function under +// construction, provided to allow additional attributes to be set. +typedef std::function* graph, std::vector* input_permutation, + std::vector* output_permutation, NodeDef* node_def)> + RewriteSubgraphFn; + +// Transformation that finds subgraphs whose nodes are marked with +// 'group_attribute', splits those subgraphs into functions, and replaces +// the originals with function calls. +// +// 'group_attribute' must be a string valued-attribute that names the new +// functions to introduce. +// +// If 'rewrite_subgraph_fn' is set, it is applied to each subgraph before +// function conversion. +// +// If 'parallel_checking' is true, the unencapsulated operators are added to the +// output graph, together with a "ParallelCheck" operator, that verifies that +// the original and encapsulated subgraphs produce similar results. +// +// TODO(phawkins): currently, some information in control edges +// is not preserved. Suppose you have A and B in the main +// graph, C and D in a subgraph. B and C have control deps from A, D has control +// dep from B. Originally D must run after C, post-transformation this +// dependency is lost. +Status EncapsulateSubgraphsInFunctions( + string group_attribute, const Graph& graph_in, + const RewriteSubgraphFn& rewrite_subgraph_fn, bool parallel_checking, + std::unique_ptr* graph_out, FunctionLibraryDefinition* library); + +// The attribute that marks function calls produced by the encapsulate +// subgraphs pass and that should in turn be compiled via _XlaLaunch operators. +extern const char* const kXlaCompiledKernelAttr; + +// Does `node` have the kXlaCompiledKernelAttr attribute? +bool IsXlaCompiledKernel(const Node& node); + +// Functions produce by the EncapsulateSubgraphs pass have their arguments +// ordered such that compile-time constant arguments are first in the argument +// order. The functions are annotated with the following attribute giving the +// number of constant arguments. 
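// [Editor's note: illustrative, not part of the original patch.]
// For an assumed cluster whose original arguments were
// (x: float, shape: int32, y: float), with only 'shape' required to be a
// compile-time constant, the rewritten function takes (shape, x, y) and the
// generated call node carries
//
//   _XlaCompiledKernel  = true
//   _XlaNumConstantArgs = 1
//
// BuildXlaLaunchOpsPass later uses this count to split the call's inputs
// into the Tconstants and Targs lists of the _XlaLaunch node.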
+extern const char* const kXlaNumConstantArgsAttr; + +class EncapsulateSubgraphsPass : public GraphOptimizationPass { + public: + Status Run(const GraphOptimizationPassOptions& options) override; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_ENCAPSULATE_SUBGRAPHS_PASS_H_ diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc new file mode 100644 index 0000000000..c85882e0d7 --- /dev/null +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -0,0 +1,397 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/graph/equal_graph_def.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b, + string* diff) { + // TODO(phawkins) use a more sophisticated equality test. 
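  // [Editor's note: illustrative, not part of the original patch.]
  // Comparing DebugString() is purely textual, so it is sensitive to details
  // that do not change the function's meaning, e.g. the order of the repeated
  // node_def field:
  //
  //   node_def { name: "C" ... } node_def { name: "D" ... }    vs.
  //   node_def { name: "D" ... } node_def { name: "C" ... }
  //
  // Two such FunctionDefs would compare unequal here even though they define
  // the same computation.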
+ if (a.DebugString() != b.DebugString()) { + if (diff) { + *diff = strings::StrCat("Definition mismatch for function ", + a.signature().name(), ", expected:\n", + a.DebugString()); + } + return false; + } + return true; +} + +bool EqualFunctionDefLibrary(const FunctionDefLibrary& expected, + const FunctionDefLibrary& actual, string* diff) { + std::unordered_map actual_index; + for (const FunctionDef& function : actual.function()) { + actual_index[function.signature().name()] = &function; + } + + for (const FunctionDef& expected_function : expected.function()) { + auto it = actual_index.find(expected_function.signature().name()); + if (it == actual_index.end()) { + if (diff) { + *diff = strings::StrCat("Did not find expected function '", + expected_function.signature().name(), "'"); + } + return false; + } + if (!EqualFunctionDef(expected_function, *it->second, diff)) return false; + actual_index.erase(it); + } + + if (!actual_index.empty()) { + if (diff != nullptr) { + *diff = strings::StrCat("Found unexpected function '", + actual_index.begin()->second->signature().name(), + "'"); + } + return false; + } + + return true; +} + +#define TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(expected, actual) \ + do { \ + string diff; \ + EXPECT_TRUE(EqualFunctionDefLibrary(actual, expected, &diff)) \ + << diff << "\nActual: " << actual.DebugString(); \ + } while (false) + +REGISTER_OP("InputTest").Output("o: float"); + +REGISTER_OP("UnaryTest").Input("a: float").Output("o: float"); +REGISTER_OP("BinaryTest") + .Input("a: float") + .Input("b: float") + .Output("o: float"); + +REGISTER_OP("AddNLikeTest") + .Input("inputs: N * T") + .Output("sum: T") + .Attr("N: int >= 1") + .Attr("T: numbertype") + .SetIsCommutative() + .SetIsAggregate(); + +Node* Input(const GraphDefBuilder::Options& opts) { + return ops::SourceOp("InputTest", opts); +} + +Node* Unary(ops::NodeOut a, const GraphDefBuilder::Options& opts) { + return ops::UnaryOp("UnaryTest", a, opts); +} + +Node* Binary(ops::NodeOut a, ops::NodeOut b, + const GraphDefBuilder::Options& opts) { + return ops::BinaryOp("BinaryTest", a, b, opts); +} + +Node* AddNLike(std::vector inputs, + const GraphDefBuilder::Options& opts) { + if (opts.HaveError()) return nullptr; + NodeBuilder node_builder(opts.GetNameForOp("AddN"), "AddNLikeTest", + opts.op_registry()); + node_builder.Input(inputs); + return opts.FinalizeBuilder(&node_builder); +} + +Node* ArgOp(int index, DataType type, const GraphDefBuilder::Options& opts) { + return ops::SourceOp("_Arg", + opts.WithAttr("T", type).WithAttr("index", index)); +} + +Node* RetOp(int index, ops::NodeOut a, const GraphDefBuilder::Options& opts) { + if (opts.HaveError()) return nullptr; + NodeBuilder node_builder(opts.GetNameForOp("Retval"), "_Retval", + opts.op_registry()); + node_builder.Input(a).Attr("index", index); + return opts.FinalizeBuilder(&node_builder); +} + +Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library) { + Status s; + // Convert the GraphDef to a Graph + std::unique_ptr lib_def( + new FunctionLibraryDefinition(OpRegistry::Global(), *library)); + GraphConstructorOptions options; + options.allow_internal_ops = true; + std::unique_ptr graph(new Graph(lib_def.get())); + s = ConvertGraphDefToGraph(options, *graphdef, graph.get()); + if (!s.ok()) return s; + + std::unique_ptr graph_out; + s = EncapsulateSubgraphsInFunctions("_encapsulate", *graph, + /* rewrite_subgraph_fn= */ {}, + /* parallel_checking= */ false, + &graph_out, lib_def.get()); + if (!s.ok()) return s; + + GraphDef graphdef_out; + 
graph_out->ToGraphDef(&graphdef_out); + graphdef->Swap(&graphdef_out); + + *library = lib_def->ToProto(); + return s; +} + +// If there are no marked nodes, funcification should be a no-op. +TEST(EncapsulateSubgraphsTest, NoFunctions) { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + + Node* a = Input(builder.opts().WithName("A")); + Node* b = Input(builder.opts().WithName("B")); + Node* c = Unary(a, builder.opts().WithName("C")); + Binary(b, c, builder.opts().WithName("D")); + + GraphDef graphdef_in; + FunctionDefLibrary library_in; + builder.ToGraphDef(&graphdef_in); + *library_in.add_function() = test::function::XTimesTwo(); + + GraphDef graphdef_out = graphdef_in; + FunctionDefLibrary library_out = library_in; + TF_EXPECT_OK(Encapsulate(&graphdef_out, &library_out)); + + TF_EXPECT_GRAPH_EQ(graphdef_in, graphdef_out); + TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_in, library_out); +} + +// Test with one function to transform. +TEST(EncapsulateSubgraphsTest, OneFunction) { + FunctionDefLibrary library; + GraphDef graphdef; + + { + *library.add_function() = test::function::XTimesTwo(); + + GraphDefBuilder b1(GraphDefBuilder::kFailImmediately); + Node* a = Input(b1.opts().WithName("A")); + Node* b = Input(b1.opts().WithName("B")); + // Give nodes 'c' and 'd' names that collide after lowercasing. + Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1")); + Node* d = Binary(b, c, b1.opts().WithName("c").WithControlInput(c).WithAttr( + "_encapsulate", "F1")); + Binary(a, d, b1.opts().WithName("E")); + b1.ToGraphDef(&graphdef); + } + + TF_EXPECT_OK(Encapsulate(&graphdef, &library)); + + FunctionDefLibrary library_expected; + GraphDef graphdef_expected; + + *library_expected.add_function() = test::function::XTimesTwo(); + *library_expected.add_function() = FunctionDefHelper::Create( + "F1", {"input__0:float", "input__1:float"}, {"output__2:float"}, {}, + { + {{"C"}, "UnaryTest", {"input__0"}}, + {{"c"}, "BinaryTest", {"input__1", "C:o:0"}, {}, {"C"}}, + }, + {{"output__2", "c:o:0"}}); + + { + std::unique_ptr lib_def( + new FunctionLibraryDefinition(OpRegistry::Global(), library_expected)); + GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); + Node* a = Input(b2.opts().WithName("A")); + Node* b = Input(b2.opts().WithName("B")); + + NodeBuilder node_builder("F1", "F1", lib_def.get()); + node_builder.Input(a).Input(b); + Node* call = b2.opts().FinalizeBuilder(&node_builder); + + Binary(a, call, b2.opts().WithName("E")); + b2.ToGraphDef(&graphdef_expected); + } + + // If there are no marked nodes, funcification should be a no-op. + TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef); + TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); +} + +// Test with two functions to transform. 
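// [Editor's note: sketch of the expected rewrite, not part of the original
// patch.] C is marked _encapsulate=F1 and D is marked _encapsulate=F2, so
// after Encapsulate() the top-level graph should contain roughly these edges:
//
//   A -> F1,  B -> F2,  F1 -> F2,  A -> E,  F2 -> E,
//   Control --(control)--> F1,  Control --(control)--> F2
//
// with the bodies of C and D moved into library functions F1 and F2, as
// spelled out in library_expected below.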
+TEST(EncapsulateSubgraphsTest, TwoFunctions) { + FunctionDefLibrary library; + GraphDef graphdef; + + { + *library.add_function() = test::function::XTimesTwo(); + + GraphDefBuilder b1(GraphDefBuilder::kFailImmediately); + Node* a = Input(b1.opts().WithName("A")); + Node* b = Input(b1.opts().WithName("B")); + Node* control = Input(b1.opts().WithName("Control")); + Node* c = + Unary(a, b1.opts().WithName("C").WithControlInput(control).WithAttr( + "_encapsulate", "F1")); + Node* d = + Binary(b, c, b1.opts().WithName("D").WithControlInput(control).WithAttr( + "_encapsulate", "F2")); + Binary(a, d, b1.opts().WithName("E")); + b1.ToGraphDef(&graphdef); + } + + TF_EXPECT_OK(Encapsulate(&graphdef, &library)); + + FunctionDefLibrary library_expected; + GraphDef graphdef_expected; + + *library_expected.add_function() = test::function::XTimesTwo(); + *library_expected.add_function() = FunctionDefHelper::Create( + "F1", {"input__0:float"}, {"output__1:float"}, {}, + { + {{"C"}, "UnaryTest", {"input__0"}}, + }, + {{"output__1", "C:o:0"}}); + *library_expected.add_function() = FunctionDefHelper::Create( + "F2", {"input__0:float", "input__1:float"}, {"output__2:float"}, {}, + { + {{"D"}, "BinaryTest", {"input__0", "input__1"}}, + }, + {{"output__2", "D:o:0"}}); + + { + std::unique_ptr lib_def( + new FunctionLibraryDefinition(OpRegistry::Global(), library_expected)); + GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); + Node* a = Input(b2.opts().WithName("A")); + Node* b = Input(b2.opts().WithName("B")); + Node* control = Input(b2.opts().WithName("Control")); + + NodeBuilder nb("F1", "F1", lib_def.get()); + nb.Input(a).ControlInput(control); + Node* call1 = b2.opts().FinalizeBuilder(&nb); + + NodeBuilder nb2("F2", "F2", lib_def.get()); + nb2.Input(b).Input(call1).ControlInput(control); + Node* call2 = b2.opts().FinalizeBuilder(&nb2); + + Binary(a, call2, b2.opts().WithName("E")); + b2.ToGraphDef(&graphdef_expected); + } + + // If there are no marked nodes, funcification should be a no-op. + TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef); + TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); +} + +// Returns a vector of node names in 'graph', sorted by name. +std::vector GraphNodes(const Graph& graph) { + std::vector nodes; + for (const auto& node : graph.nodes()) { + if (!node->IsSource() && !node->IsSink()) { + nodes.push_back(node->name()); + } + } + std::sort(nodes.begin(), nodes.end()); + return nodes; +} + +// Returns a sorted vector of (src, dst) edges in 'graph'. 
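// [Editor's note: illustrative, not part of the original patch.]
// Each data edge is rendered as a ("src:out_port", "dst:in_port") pair of
// strings, e.g. an assumed edge from output 0 of "x" to input 1 of "mul"
// appears as
//
//   {"x:0", "mul:1"}
//
// which keeps the expected-edge lists in the tests below easy to read.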
+std::vector> GraphEdges(const Graph& graph) { + std::vector> edges; + for (const Edge* edge : graph.edges()) { + if (edge->src()->IsSource() || edge->dst()->IsSink()) continue; + edges.emplace_back( + strings::StrCat(edge->src()->name(), ":", edge->src_output()), + strings::StrCat(edge->dst()->name(), ":", edge->dst_input())); + } + std::sort(edges.begin(), edges.end()); + return edges; +} + +TEST(EncapsulateSubgraphsTest, InputDeduplication) { + Scope root = Scope::NewRootScope().ExitOnError().WithDevice( + "/job:localhost/replica:0/task:0/cpu:0"); + auto x = ops::Placeholder(root.WithOpName("x"), DT_FLOAT); + auto add1 = ops::Add(root.WithOpName("add1"), x, x); + add1.node()->AddAttr("_cluster", "cluster1"); + auto add2 = ops::Add(root.WithOpName("add2"), add1, add1); + add2.node()->AddAttr("_cluster", "cluster2"); + auto out = ops::Mul(root.WithOpName("mul"), add1, add2); + + Graph graph_before_encapsulation(OpRegistry::Global()); + TF_ASSERT_OK(root.ToGraph(&graph_before_encapsulation)); + + FunctionLibraryDefinition library(OpRegistry::Global(), {}); + std::unique_ptr graph; + TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( + "_cluster", graph_before_encapsulation, /*rewrite_subgraph_fn=*/{}, + /*parallel_checking=*/false, &graph, &library)); + + std::vector expected_nodes = {"cluster1", "cluster2", "mul", "x"}; + EXPECT_EQ(expected_nodes, GraphNodes(*graph)); + + std::vector> expected_edges = { + {"cluster1:0", "cluster2:0"}, + {"cluster1:0", "mul:0"}, + {"cluster2:0", "mul:1"}, + {"x:0", "cluster1:0"}}; + EXPECT_EQ(expected_edges, GraphEdges(*graph)); +} + +TEST(EncapsulateSubgraphsTest, ParallelChecking) { + Scope root = Scope::NewRootScope().ExitOnError().WithDevice( + "/job:localhost/replica:0/task:0/cpu:0"); + auto x1 = ops::Placeholder(root.WithOpName("x1"), DT_FLOAT); + auto x2 = ops::Placeholder(root.WithOpName("x2"), DT_FLOAT); + auto add1 = ops::Add(root.WithOpName("add1"), x1, x2); + add1.node()->AddAttr("_cluster", "cluster1"); + auto add2 = ops::Add(root.WithOpName("add2"), add1, x2); + add2.node()->AddAttr("_cluster", "cluster1"); + auto out = ops::Mul(root.WithOpName("mul"), x1, add2); + + Graph graph_before_encapsulation(OpRegistry::Global()); + TF_ASSERT_OK(root.ToGraph(&graph_before_encapsulation)); + + FunctionLibraryDefinition library(OpRegistry::Global(), {}); + std::unique_ptr graph; + TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( + "_cluster", graph_before_encapsulation, /*rewrite_subgraph_fn=*/{}, + /*parallel_checking=*/true, &graph, &library)); + + std::vector expected_nodes = { + "add1", "add2", "cluster1", "cluster1_parallel_check/_0", + "mul", "x1", "x2"}; + EXPECT_EQ(expected_nodes, GraphNodes(*graph)); + + std::vector> expected_edges = { + {"add1:0", "add2:0"}, + {"add2:0", "cluster1_parallel_check/_0:0"}, + {"cluster1:0", "cluster1_parallel_check/_0:1"}, + {"cluster1_parallel_check/_0:0", "mul:1"}, + {"x1:0", "add1:0"}, + {"x1:0", "cluster1:0"}, + {"x1:0", "mul:0"}, + {"x2:0", "add1:1"}, + {"x2:0", "add2:1"}, + {"x2:0", "cluster1:1"}, + }; + EXPECT_EQ(expected_edges, GraphEdges(*graph)); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/graph_to_functiondef.cc b/tensorflow/compiler/jit/graph_to_functiondef.cc new file mode 100644 index 0000000000..a589a94fd4 --- /dev/null +++ b/tensorflow/compiler/jit/graph_to_functiondef.cc @@ -0,0 +1,274 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/graph_to_functiondef.h"
+
+#include <unordered_map>
+#include <unordered_set>
+
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace tensorflow {
+namespace {
+
+// TODO(phawkins) add a canonical copy of these operator names and refactor
+// everything to use it.
+const char* const kArgOp = "_Arg";
+const char* const kRetValOp = "_Retval";
+
+// Class that maintains a one-to-one original node name -> new name mapping.
+// We have to normalize the names used as input and output arguments to
+// match regexp "[a-z][a-z0-9_]*". Once we rename them, we risk creating
+// a name collision with the other node names, so if necessary we add
+// a suffix to make names unique. So if we have an input named "A" and a
+// node in the function body named "a", they will be renamed to "a" and "a_0".
+class NodeNameMapping {
+ public:
+  NodeNameMapping() = default;
+
+  // Normalize the input/output name and then make it unique.
+  string Normalize(const string& name);
+
+  // Make the node name unique.
+  string Uniquify(const string& name);
+
+  // Look up how a node name was previously normalized/uniquified.
+  // Returns empty if name was never seen.
+  string Renormalize(const string& name) const;
+
+ private:
+  string NormalizeHelper(string name) const;
+  string UniquifyHelper(string name);
+
+  std::unordered_set<string> used_names_;
+  std::unordered_map<string, string> name_mapping_;
+};
+
+string NodeNameMapping::NormalizeHelper(string name) const {
+  // Convert letters to lowercase and non-alphanumeric characters to '_'.
+  if (name.empty()) name = "unknown";
+  const int n = name.size();
+  for (int i = 0; i < n; i++) {
+    char c = name[i];
+    if (isalnum(c)) {
+      if (isupper(c)) {
+        name[i] = tolower(c);
+      }
+    } else {
+      name[i] = '_';
+    }
+  }
+  return name;
+}
+
+string NodeNameMapping::UniquifyHelper(string name) {
+  // If the name hasn't been used yet, use it as-is.
+  if (used_names_.insert(name).second) return name;
+  // Add a suffix to name to make it unique.
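+  // For example, a second request for "a" yields "a_0", a third "a_1", etc.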
+ for (int i = 0;; ++i) { + const string candidate = strings::StrCat(name, "_", i); + if (used_names_.insert(candidate).second) return candidate; + } +} + +string NodeNameMapping::Normalize(const string& name) { + const string normalized = UniquifyHelper(NormalizeHelper(name)); + name_mapping_[name] = normalized; + return normalized; +} + +string NodeNameMapping::Uniquify(const string& name) { + const string uniqued = UniquifyHelper(name); + name_mapping_[name] = uniqued; + return uniqued; +} + +string NodeNameMapping::Renormalize(const string& name) const { + const auto iter = name_mapping_.find(name); + if (iter == name_mapping_.end()) return string(); + return iter->second; +} + +} // anonymous namespace + +// Graph to FunctionDef conversion. This code is closely modeled on the Python +// code in third_party/tensorflow/python/framework/function.py. + +Status GraphToFunctionDef(const Graph& graph, const string& name, + FunctionDef* fdef) { + fdef->mutable_signature()->set_name(name); + + std::unordered_map tensor_renaming; + std::unordered_map return_values; + NodeNameMapping node_names; + + for (Node const* node : graph.nodes()) { + if (!node->IsOp()) continue; + + if (node->type_string() == kArgOp) { + int index; + DataType type; + GetNodeAttr(node->def(), "T", &type); + GetNodeAttr(node->def(), "index", &index); + while (fdef->signature().input_arg_size() <= index) { + fdef->mutable_signature()->add_input_arg(); + } + OpDef::ArgDef* argdef = + fdef->mutable_signature()->mutable_input_arg(index); + argdef->set_type(type); + const string normalized = node_names.Normalize(node->name()); + argdef->set_name(normalized); + tensor_renaming[strings::StrCat(node->name(), ":0")] = normalized; + continue; + } + + if (node->type_string() == kRetValOp) { + int index; + DataType type; + GetNodeAttr(node->def(), "T", &type); + GetNodeAttr(node->def(), "index", &index); + while (fdef->signature().output_arg_size() <= index) { + fdef->mutable_signature()->add_output_arg(); + } + OpDef::ArgDef* argdef = + fdef->mutable_signature()->mutable_output_arg(index); + argdef->set_type(type); + const string normalized = node_names.Normalize(node->name()); + argdef->set_name(normalized); + CHECK_EQ(node->in_edges().size(), 1); + Edge const* edge = *node->in_edges().begin(); + return_values[normalized] = + strings::StrCat(edge->src()->name(), ":", edge->src_output()); + continue; + } + + NodeDef* node_def = fdef->add_node_def(); + node_def->CopyFrom(node->def()); + node_def->set_name(node_names.Uniquify(node->name())); + node_def->clear_device(); + + // Reset input names based on graph rather than the NodeDef. + node_def->clear_input(); + + // Edges, indexed by dst_input. 
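+      // Data edges are collected by destination input index so they can be
+      // re-emitted in order; control edges are collected separately and
+      // emitted below with a "^" prefix.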
+ std::vector in_edges; + std::vector control_edges; + for (Edge const* edge : node->in_edges()) { + if (edge->src()->IsSource()) continue; + + if (edge->IsControlEdge()) { + control_edges.push_back(edge); + } else { + if (in_edges.size() <= edge->dst_input()) { + in_edges.resize(edge->dst_input() + 1); + } + in_edges[edge->dst_input()] = edge; + } + } + + // Add regular inputs + for (int i = 0; i < in_edges.size(); ++i) { + const Edge* edge = in_edges[i]; + if (edge == nullptr) { + return errors::InvalidArgument( + "Nonconsecutive input edges; missing " + "input edge ", + i, " for node ", node->name()); + } + node_def->add_input( + strings::StrCat(edge->src()->name(), ":", edge->src_output())); + } + + // Add control inputs + for (const Edge* edge : control_edges) { + node_def->add_input(strings::StrCat("^", edge->src()->name())); + } + + // Populate tensor_renaming. + NameRangeMap output_ranges; + TF_RETURN_IF_ERROR(NameRangesForNode(node->def(), node->op_def(), nullptr, + &output_ranges)); + for (const auto& output : output_ranges) { + for (int i = output.second.first; i < output.second.second; ++i) { + const string tensor_name = strings::StrCat( + node_def->name(), ":", output.first, ":", i - output.second.first); + tensor_renaming[strings::StrCat(node->name(), ":", i)] = tensor_name; + } + } + } + + // Detect missing function inputs. + for (int i = 0; i < fdef->signature().input_arg_size(); ++i) { + const string& input_name = fdef->signature().input_arg(i).name(); + if (input_name.empty()) { + return errors::InvalidArgument("Missing input ", i, " to function ", + name); + } + } + + // Remap input names. We do this as a second pass to allow the nodes to be in + // any order. + for (int n_index = 0; n_index < fdef->node_def_size(); ++n_index) { + NodeDef* node_def = fdef->mutable_node_def(n_index); + for (int i = 0; i < node_def->input_size(); ++i) { + if (StringPiece(node_def->input(i)).starts_with("^")) { + // Control input + const string normalized = + node_names.Renormalize(node_def->input(i).substr(1)); + if (normalized.empty()) { + return errors::InvalidArgument( + "Could not remap control input ", i, ", '", node_def->input(i), + "', of node '", node_def->name(), "' in function ", name); + } + *node_def->mutable_input(i) = strings::StrCat("^", normalized); + } else { + const auto iter = tensor_renaming.find(node_def->input(i)); + if (iter == tensor_renaming.end()) { + return errors::InvalidArgument( + "Could not remap input ", i, ", '", node_def->input(i), + "', of node '", node_def->name(), "' in function ", name); + } + *node_def->mutable_input(i) = iter->second; + } + } + } + + // Remap return values. 
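+  // Each output arg name maps (via return_values) to the tensor that feeds
+  // its _Retval node; rewrite that tensor name using tensor_renaming.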
+ for (int r = 0; r < fdef->signature().output_arg_size(); ++r) { + const string& ret_name = fdef->signature().output_arg(r).name(); + if (ret_name.empty()) { + return errors::InvalidArgument("Missing output ", r, " to function ", + name); + } + const string& return_value = return_values[ret_name]; + const auto iter = tensor_renaming.find(return_value); + if (iter == tensor_renaming.end()) { + return errors::InvalidArgument("Could not remap return value ", r, ", '", + ret_name, "', of '", return_value, + "' in function ", name); + } + (*fdef->mutable_ret())[ret_name] = iter->second; + } + + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/graph_to_functiondef.h b/tensorflow/compiler/jit/graph_to_functiondef.h new file mode 100644 index 0000000000..3e1ae7bbbe --- /dev/null +++ b/tensorflow/compiler/jit/graph_to_functiondef.h @@ -0,0 +1,33 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_GRAPH_TO_FUNCTIONDEF_H_ +#define TENSORFLOW_COMPILER_JIT_GRAPH_TO_FUNCTIONDEF_H_ + +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// Converts 'graph' to a FunctionDef 'fdef', with name 'name'. +// Closely modeled on the Python code in +// third_party/tensorflow/python/framework/function.py +Status GraphToFunctionDef(const Graph& graph, const string& name, + FunctionDef* fdef); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_GRAPH_TO_FUNCTIONDEF_H_ diff --git a/tensorflow/compiler/jit/graph_to_functiondef_test.cc b/tensorflow/compiler/jit/graph_to_functiondef_test.cc new file mode 100644 index 0000000000..df45f455a9 --- /dev/null +++ b/tensorflow/compiler/jit/graph_to_functiondef_test.cc @@ -0,0 +1,87 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/jit/graph_to_functiondef.h" + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/graph/equal_graph_def.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b, + string* diff) { + // TODO(phawkins) use a more sophisticated equality test. + if (a.DebugString() != b.DebugString()) { + if (diff) { + *diff = strings::StrCat("Definition mismatch for function ", + a.signature().name(), ":\n", a.DebugString(), + "\n ---- vs. ----\n", b.DebugString()); + } + return false; + } + return true; +} + +TEST(GraphToFunctionDefTest, Basics) { + Scope root = Scope::NewRootScope().ExitOnError(); + auto a = ops::_Arg(root.WithOpName("A"), DT_FLOAT, 0); + auto b = ops::_Arg(root.WithOpName("B"), DT_FLOAT, 1); + auto c = ops::_Arg(root.WithOpName("C"), DT_FLOAT, 2); + auto d = ops::Add(root.WithOpName("D"), a, b); + auto e = ops::Add(root.WithOpName("b"), d, c); + auto f = ops::Neg(root.WithOpName("h"), e); + auto g = + ops::AddN(root.WithOpName("G"), std::initializer_list{e, f}); + auto h = ops::_Retval(root.WithOpName("H"), g, 0); + + GraphDef graph_def; + root.ToGraphDef(&graph_def); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + GraphConstructorOptions options; + TF_EXPECT_OK(ConvertGraphDefToGraph(options, graph_def, graph.get())); + + FunctionDef fdef; + TF_EXPECT_OK(GraphToFunctionDef(*graph, "test_fn", &fdef)); + + FunctionDef fdef_expected = FunctionDefHelper::Create( + "test_fn", // function name + {"a: float", "b: float", "c: float"}, // inputs + {"h_0: float"}, // outputs + {}, // attrs + { + // nodes in the function body + {{"D"}, "Add", {"a", "b"}, {{"T", DT_FLOAT}}}, + {{"b_0"}, "Add", {"D:z:0", "c"}, {{"T", DT_FLOAT}}}, + {{"h"}, "Neg", {"b_0:z:0"}, {{"T", DT_FLOAT}}}, + {{"G"}, "AddN", {"b_0:z:0", "h:y:0"}, {{"N", 2}, {"T", DT_FLOAT}}}, + }, + {{"h_0", "G:sum:0"}}); // return values + + string diff; + bool fdefs_equal = EqualFunctionDef(fdef_expected, fdef, &diff); + EXPECT_TRUE(fdefs_equal) << diff; +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/graphcycles/BUILD b/tensorflow/compiler/jit/graphcycles/BUILD new file mode 100644 index 0000000000..ce634529cb --- /dev/null +++ b/tensorflow/compiler/jit/graphcycles/BUILD @@ -0,0 +1,41 @@ +licenses(["notice"]) # Apache 2.0 + +package( + default_visibility = [ + "//tensorflow/compiler/tf2xla:internal", + ], +) + +cc_library( + name = "graphcycles", + srcs = ["graphcycles.cc"], + hdrs = ["graphcycles.h"], + deps = [ + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "graphcycles_test", + srcs = ["graphcycles_test.cc"], + deps = [ + ":graphcycles", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +# ----------------------------------------------------------------------------- + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/compiler/jit/graphcycles/graphcycles.cc 
b/tensorflow/compiler/jit/graphcycles/graphcycles.cc
new file mode 100644
index 0000000000..87d5de09d1
--- /dev/null
+++ b/tensorflow/compiler/jit/graphcycles/graphcycles.cc
@@ -0,0 +1,391 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// GraphCycles provides incremental cycle detection on a dynamic
+// graph using the following algorithm:
+//
+//   A dynamic topological sort algorithm for directed acyclic graphs
+//   David J. Pearce, Paul H. J. Kelly
+//   Journal of Experimental Algorithmics (JEA),
+//   Volume 11, 2006, Article No. 1.7
+//
+// Brief summary of the algorithm:
+//
+// (1) Maintain a rank for each node that is consistent
+//     with the topological sort of the graph. I.e., path from x to y
+//     implies rank[x] < rank[y].
+// (2) When a new edge (x->y) is inserted, do nothing if rank[x] < rank[y].
+// (3) Otherwise: adjust ranks in the neighborhood of x and y.
+
+#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
+
+#include <algorithm>
+#include <unordered_set>
+
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+namespace {
+
+typedef std::unordered_set<int32> NodeSet;
+template <typename T>
+struct VecStruct {
+  typedef gtl::InlinedVector<T, 4> type;
+};
+template <typename T>
+using Vec = typename VecStruct<T>::type;
+
+struct Node {
+  Node() : in(4), out(4) {}  // Small hashtables for in/out edges
+
+  int32 rank;    // rank number assigned by Pearce-Kelly algorithm
+  bool visited;  // Temporary marker used by depth-first-search
+  void* data;    // User-supplied data
+  NodeSet in;    // List of immediate predecessor nodes in graph
+  NodeSet out;   // List of immediate successor nodes in graph
+};
+
+}  // namespace
+
+struct GraphCycles::Rep {
+  Vec<Node*> nodes_;
+  Vec<int32> free_nodes_;  // Indices for unused entries in nodes_
+
+  // Temporary state.
+  Vec<int32> deltaf_;  // Results of forward DFS
+  Vec<int32> deltab_;  // Results of backward DFS
+  Vec<int32> list_;    // All nodes to reprocess
+  Vec<int32> merged_;  // Rank values to assign to list_ entries
+  Vec<int32> stack_;   // Emulates recursion stack when doing depth first search
+};
+
+GraphCycles::GraphCycles() : rep_(new Rep) {}
+
+GraphCycles::~GraphCycles() {
+  for (int i = 0; i < rep_->nodes_.size(); i++) {
+    delete rep_->nodes_[i];
+  }
+  delete rep_;
+}
+
+bool GraphCycles::CheckInvariants() const {
+  Rep* r = rep_;
+  NodeSet ranks;  // Set of ranks seen so far.
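+  // Every node must carry a distinct rank, and every edge must go from a
+  // lower rank to a higher rank.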
+ for (int32 x = 0; x < r->nodes_.size(); x++) { + Node* nx = r->nodes_[x]; + if (nx->visited) { + LOG(FATAL) << "Did not clear visited marker on node " << x; + } + if (!ranks.insert(nx->rank).second) { + LOG(FATAL) << "Duplicate occurrence of rank " << nx->rank; + } + for (auto y : nx->out) { + Node* ny = r->nodes_[y]; + if (nx->rank >= ny->rank) { + LOG(FATAL) << "Edge " << x << "->" << y << " has bad rank assignment " + << nx->rank << "->" << ny->rank; + } + } + } + return true; +} + +int32 GraphCycles::NewNode() { + if (rep_->free_nodes_.empty()) { + Node* n = new Node; + n->visited = false; + n->data = NULL; + n->rank = rep_->nodes_.size(); + rep_->nodes_.push_back(n); + return n->rank; + } else { + // Preserve preceding rank since the set of ranks in use must be + // a permutation of [0,rep_->nodes_.size()-1]. + int32 r = rep_->free_nodes_.back(); + rep_->nodes_[r]->data = NULL; + rep_->free_nodes_.pop_back(); + return r; + } +} + +void GraphCycles::RemoveNode(int32 node) { + Node* x = rep_->nodes_[node]; + for (auto y : x->out) { + rep_->nodes_[y]->in.erase(node); + } + for (auto y : x->in) { + rep_->nodes_[y]->out.erase(node); + } + x->in.clear(); + x->out.clear(); + rep_->free_nodes_.push_back(node); +} + +void* GraphCycles::GetNodeData(int32 node) const { + return rep_->nodes_[node]->data; +} + +void GraphCycles::SetNodeData(int32 node, void* data) { + rep_->nodes_[node]->data = data; +} + +bool GraphCycles::HasEdge(int32 x, int32 y) const { + return rep_->nodes_[x]->out.find(y) != rep_->nodes_[x]->out.end(); +} + +void GraphCycles::RemoveEdge(int32 x, int32 y) { + rep_->nodes_[x]->out.erase(y); + rep_->nodes_[y]->in.erase(x); + // No need to update the rank assignment since a previous valid + // rank assignment remains valid after an edge deletion. +} + +static bool ForwardDFS(GraphCycles::Rep* r, int32 n, int32 upper_bound); +static void BackwardDFS(GraphCycles::Rep* r, int32 n, int32 lower_bound); +static void Reorder(GraphCycles::Rep* r); +static void Sort(const Vec&, Vec* delta); +static void MoveToList(GraphCycles::Rep* r, Vec* src, Vec* dst); +static void ClearVisitedBits(GraphCycles::Rep* r, const Vec& nodes); + +bool GraphCycles::InsertEdge(int32 x, int32 y) { + if (x == y) return false; + Rep* r = rep_; + Node* nx = r->nodes_[x]; + if (!nx->out.insert(y).second) { + // Edge already exists. + return true; + } + + Node* ny = r->nodes_[y]; + ny->in.insert(x); + + if (nx->rank <= ny->rank) { + // New edge is consistent with existing rank assignment. + return true; + } + + // Current rank assignments are incompatible with the new edge. Recompute. + // We only need to consider nodes that fall in the range [ny->rank,nx->rank]. + if (!ForwardDFS(r, y, nx->rank)) { + // Found a cycle. Undo the insertion and tell caller. + nx->out.erase(y); + ny->in.erase(x); + // Since we do not call Reorder() on this path, clear any visited + // markers left by ForwardDFS. + ClearVisitedBits(r, r->deltaf_); + return false; + } + BackwardDFS(r, x, ny->rank); + Reorder(r); + return true; +} + +static bool ForwardDFS(GraphCycles::Rep* r, int32 n, int32 upper_bound) { + // Avoid recursion since stack space might be limited. + // We instead keep a stack of nodes to visit. 
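+  // Only nodes with rank < upper_bound are explored; returns false if the
+  // search reaches the node whose rank equals upper_bound, i.e. the edge
+  // being inserted would close a cycle.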
+ r->deltaf_.clear(); + r->stack_.clear(); + r->stack_.push_back(n); + while (!r->stack_.empty()) { + n = r->stack_.back(); + r->stack_.pop_back(); + Node* nn = r->nodes_[n]; + if (nn->visited) continue; + + nn->visited = true; + r->deltaf_.push_back(n); + + for (auto w : nn->out) { + Node* nw = r->nodes_[w]; + if (nw->rank == upper_bound) { + return false; // Cycle + } + if (!nw->visited && nw->rank < upper_bound) { + r->stack_.push_back(w); + } + } + } + return true; +} + +static void BackwardDFS(GraphCycles::Rep* r, int32 n, int32 lower_bound) { + r->deltab_.clear(); + r->stack_.clear(); + r->stack_.push_back(n); + while (!r->stack_.empty()) { + n = r->stack_.back(); + r->stack_.pop_back(); + Node* nn = r->nodes_[n]; + if (nn->visited) continue; + + nn->visited = true; + r->deltab_.push_back(n); + + for (auto w : nn->in) { + Node* nw = r->nodes_[w]; + if (!nw->visited && lower_bound < nw->rank) { + r->stack_.push_back(w); + } + } + } +} + +static void Reorder(GraphCycles::Rep* r) { + Sort(r->nodes_, &r->deltab_); + Sort(r->nodes_, &r->deltaf_); + + // Adds contents of delta lists to list_ (backwards deltas first). + r->list_.clear(); + MoveToList(r, &r->deltab_, &r->list_); + MoveToList(r, &r->deltaf_, &r->list_); + + // Produce sorted list of all ranks that will be reassigned. + r->merged_.resize(r->deltab_.size() + r->deltaf_.size()); + std::merge(r->deltab_.begin(), r->deltab_.end(), r->deltaf_.begin(), + r->deltaf_.end(), r->merged_.begin()); + + // Assign the ranks in order to the collected list. + for (int32 i = 0; i < r->list_.size(); i++) { + r->nodes_[r->list_[i]]->rank = r->merged_[i]; + } +} + +static void Sort(const Vec& nodes, Vec* delta) { + struct ByRank { + const Vec* nodes; + bool operator()(int32 a, int32 b) const { + return (*nodes)[a]->rank < (*nodes)[b]->rank; + } + }; + ByRank cmp; + cmp.nodes = &nodes; + std::sort(delta->begin(), delta->end(), cmp); +} + +static void MoveToList(GraphCycles::Rep* r, Vec* src, Vec* dst) { + for (int32 i = 0; i < src->size(); i++) { + int32 w = (*src)[i]; + (*src)[i] = r->nodes_[w]->rank; // Replace src entry with its rank + r->nodes_[w]->visited = false; // Prepare for future DFS calls + dst->push_back(w); + } +} + +static void ClearVisitedBits(GraphCycles::Rep* r, const Vec& nodes) { + for (int32 i = 0; i < nodes.size(); i++) { + r->nodes_[nodes[i]]->visited = false; + } +} + +int GraphCycles::FindPath(int32 x, int32 y, int max_path_len, + int32 path[]) const { + // Forward depth first search starting at x until we hit y. + // As we descend into a node, we push it onto the path. + // As we leave a node, we remove it from the path. 
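+  // A -1 pushed onto the stack marks the point at which we backtrack and
+  // drop the current node from the path.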
+ int path_len = 0; + + Rep* r = rep_; + NodeSet seen; + r->stack_.clear(); + r->stack_.push_back(x); + while (!r->stack_.empty()) { + int32 n = r->stack_.back(); + r->stack_.pop_back(); + if (n < 0) { + // Marker to indicate that we are leaving a node + path_len--; + continue; + } + + if (path_len < max_path_len) { + path[path_len] = n; + } + path_len++; + r->stack_.push_back(-1); // Will remove tentative path entry + + if (n == y) { + return path_len; + } + + for (auto w : r->nodes_[n]->out) { + if (seen.insert(w).second) { + r->stack_.push_back(w); + } + } + } + + return 0; +} + +bool GraphCycles::IsReachable(int32 x, int32 y) const { + return FindPath(x, y, 0, NULL) > 0; +} + +bool GraphCycles::IsReachableNonConst(int32 x, int32 y) { + if (x == y) return true; + Rep* r = rep_; + Node* nx = r->nodes_[x]; + Node* ny = r->nodes_[y]; + + if (nx->rank >= ny->rank) { + // x cannot reach y since it is after it in the topological ordering + return false; + } + + // See if x can reach y using a DFS search that is limited to y's rank + bool reachable = !ForwardDFS(r, x, ny->rank); + + // Clear any visited markers left by ForwardDFS. + ClearVisitedBits(r, r->deltaf_); + return reachable; +} + +bool GraphCycles::ContractEdge(int32 a, int32 b) { + CHECK(HasEdge(a, b)); + RemoveEdge(a, b); + + if (IsReachableNonConst(a, b)) { + // Restore the graph to its original state. + InsertEdge(a, b); + return false; + } + + Node* nb = rep_->nodes_[b]; + std::unordered_set out = std::move(nb->out); + std::unordered_set in = std::move(nb->in); + for (auto y : out) { + rep_->nodes_[y]->in.erase(b); + } + for (auto y : in) { + rep_->nodes_[y]->out.erase(b); + } + rep_->free_nodes_.push_back(b); + + for (auto y : out) { + InsertEdge(a, y); + } + for (auto y : in) { + InsertEdge(y, a); + } + return true; +} + +std::unordered_set GraphCycles::Successors(int32 node) { + return rep_->nodes_[node]->out; +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/graphcycles/graphcycles.h b/tensorflow/compiler/jit/graphcycles/graphcycles.h new file mode 100644 index 0000000000..d11d6e27b1 --- /dev/null +++ b/tensorflow/compiler/jit/graphcycles/graphcycles.h @@ -0,0 +1,128 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_GRAPHCYCLES_H_ +#define TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_GRAPHCYCLES_H_ + +// GraphCycles detects the introduction of a cycle into a directed +// graph that is being built up incrementally. +// +// Nodes are identified by small integers. It is not possible to +// record multiple edges with the same (source, destination) pair; +// requests to add an edge where one already exists are silently +// ignored. +// +// It is also not possible to introduce a cycle; an attempt to insert +// an edge that would introduce a cycle fails and returns false. 
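+// In that case the graph is left unchanged.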
+// +// GraphCycles uses no internal locking; calls into it should be +// serialized externally. + +// Performance considerations: +// Works well on sparse graphs, poorly on dense graphs. +// Extra information is maintained incrementally to detect cycles quickly. +// InsertEdge() is very fast when the edge already exists, and reasonably fast +// otherwise. +// FindPath() is linear in the size of the graph. +// The current implementation uses O(|V|+|E|) space. + +#include + +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// NOTE!!! +// For now a copy of this is forked to net/plaque. If you +// find a bug or add a feature, please inform the owners of the +// net/plaque copy in case it should be integrated. +// NOTE!!! +class GraphCycles { + public: + GraphCycles(); + ~GraphCycles(); + + // Allocate an unused node id and return it. + // The new node has a null pointer for its node data. + // All node identifiers passed to other routines in this interface + // must have been allocated by NewNode() and not yet deallocated + // by RemoveNode(). + int32 NewNode(); + + // Remove "node" from the graph, deleting all edges to and from it. + // After this call the identifier "node" it may no longer be used + // as an argument to any routine until it has been reallocated with + // NewNode(). + void RemoveNode(int32 node); + + // Attempt to insert an edge from source_node to dest_node. If the + // edge would introduce a cycle, return false without making any + // changes. Otherwise add the edge and return true. + bool InsertEdge(int32 source_node, int32 dest_node); + + // Remove any edge that exists from source_node to dest_node. + void RemoveEdge(int32 source_node, int32 dest_node); + + // Return whether there is an edge directly from source_node to dest_node. + bool HasEdge(int32 source_node, int32 dest_node) const; + + // Contracts the edge from 'a' to node 'b', merging nodes 'a' and 'b'. 'b' is + // removed from the graph, and edges to/from 'b' are replaced with edges + // to/from 'a'. If contracting the edge would create a cycle, does nothing + // and returns false. + bool ContractEdge(int32 a, int32 b); + + // Return whether dest_node is reachable from source_node + // by following edges. + bool IsReachable(int32 source_node, int32 dest_node) const; + + // A faster non-thread-safe version of IsReachable. + bool IsReachableNonConst(int32 source_node, int32 dest_node); + + // Return or set the node data for a node. This data is unused + // by the implementation. + void *GetNodeData(int32 node) const; + void SetNodeData(int32 node, void *data); + + // Find a path from "source" to "dest". If such a path exists, place the + // node IDs of the nodes on the path in the array path[], and return the + // number of nodes on the path. If the path is longer than max_path_len + // nodes, only the first max_path_len nodes are placed in path[]. The client + // should compare the return value with max_path_len" to see when this + // occurs. If no path exists, return 0. Any valid path stored in path[] + // will start with "source" and end with "dest". There is no guarantee that + // the path is the shortest, but no node will appear twice in the path, + // except the source and destination node if they are identical; therefore, + // the return value is at most one greater than the number of nodes in the + // graph. + int FindPath(int32 source, int32 dest, int max_path_len, int32 path[]) const; + + // Check internal invariants. 
Crashes on failure, returns true on success. + // Expensive: should only be called from graphcycles_test.cc. + bool CheckInvariants() const; + + std::unordered_set Successors(int32 node); + + // ---------------------------------------------------- + struct Rep; + + private: + Rep *rep_; // opaque representation + TF_DISALLOW_COPY_AND_ASSIGN(GraphCycles); +}; + +} // namespace tensorflow +#endif // TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_GRAPHCYCLES_H_ diff --git a/tensorflow/compiler/jit/graphcycles/graphcycles_test.cc b/tensorflow/compiler/jit/graphcycles/graphcycles_test.cc new file mode 100644 index 0000000000..f27a616ac9 --- /dev/null +++ b/tensorflow/compiler/jit/graphcycles/graphcycles_test.cc @@ -0,0 +1,515 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// A test for the GraphCycles interface. + +#include "tensorflow/compiler/jit/graphcycles/graphcycles.h" + +#include +#include +#include + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/platform/types.h" + +using tensorflow::int32; +using tensorflow::string; + +// We emulate a GraphCycles object with a node vector and an edge vector. +// We then compare the two implementations. + +typedef std::vector Nodes; +struct Edge { + int from; + int to; +}; +typedef std::vector Edges; + +// Return whether "to" is reachable from "from". 
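+// This is a simple recursive reference implementation used to cross-check
+// GraphCycles::IsReachable().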
+static bool IsReachable(Edges *edges, int from, int to, + std::unordered_set *seen) { + seen->insert(from); // we are investigating "from"; don't do it again + if (from == to) return true; + for (int i = 0; i != edges->size(); i++) { + Edge *edge = &(*edges)[i]; + if (edge->from == from) { + if (edge->to == to) { // success via edge directly + return true; + } else if (seen->find(edge->to) == seen->end() && // success via edge + IsReachable(edges, edge->to, to, seen)) { + return true; + } + } + } + return false; +} + +static void PrintNodes(Nodes *nodes) { + LOG(INFO) << "NODES (" << nodes->size() << ")"; + for (int i = 0; i != nodes->size(); i++) { + LOG(INFO) << (*nodes)[i]; + } +} + +static void PrintEdges(Edges *edges) { + LOG(INFO) << "EDGES (" << edges->size() << ")"; + for (int i = 0; i != edges->size(); i++) { + int a = (*edges)[i].from; + int b = (*edges)[i].to; + LOG(INFO) << a << " " << b; + } + LOG(INFO) << "---"; +} + +static void PrintGCEdges(Nodes *nodes, tensorflow::GraphCycles *gc) { + LOG(INFO) << "GC EDGES"; + for (int i = 0; i != nodes->size(); i++) { + for (int j = 0; j != nodes->size(); j++) { + int a = (*nodes)[i]; + int b = (*nodes)[j]; + if (gc->HasEdge(a, b)) { + LOG(INFO) << a << " " << b; + } + } + } + LOG(INFO) << "---"; +} + +static void PrintTransitiveClosure(Nodes *nodes, Edges *edges, + tensorflow::GraphCycles *gc) { + LOG(INFO) << "Transitive closure"; + for (int i = 0; i != nodes->size(); i++) { + for (int j = 0; j != nodes->size(); j++) { + int a = (*nodes)[i]; + int b = (*nodes)[j]; + std::unordered_set seen; + if (IsReachable(edges, a, b, &seen)) { + LOG(INFO) << a << " " << b; + } + } + } + LOG(INFO) << "---"; +} + +static void PrintGCTransitiveClosure(Nodes *nodes, + tensorflow::GraphCycles *gc) { + LOG(INFO) << "GC Transitive closure"; + for (int i = 0; i != nodes->size(); i++) { + for (int j = 0; j != nodes->size(); j++) { + int a = (*nodes)[i]; + int b = (*nodes)[j]; + if (gc->IsReachable(a, b)) { + LOG(INFO) << a << " " << b; + } + } + } + LOG(INFO) << "---"; +} + +static void CheckTransitiveClosure(Nodes *nodes, Edges *edges, + tensorflow::GraphCycles *gc) { + std::unordered_set seen; + for (int i = 0; i != nodes->size(); i++) { + for (int j = 0; j != nodes->size(); j++) { + seen.clear(); + int a = (*nodes)[i]; + int b = (*nodes)[j]; + bool gc_reachable = gc->IsReachable(a, b); + CHECK_EQ(gc_reachable, gc->IsReachableNonConst(a, b)); + bool reachable = IsReachable(edges, a, b, &seen); + if (gc_reachable != reachable) { + PrintEdges(edges); + PrintGCEdges(nodes, gc); + PrintTransitiveClosure(nodes, edges, gc); + PrintGCTransitiveClosure(nodes, gc); + LOG(FATAL) << "gc_reachable " << gc_reachable << " reachable " + << reachable << " a " << a << " b " << b; + } + } + } +} + +static void CheckEdges(Nodes *nodes, Edges *edges, + tensorflow::GraphCycles *gc) { + int count = 0; + for (int i = 0; i != edges->size(); i++) { + int a = (*edges)[i].from; + int b = (*edges)[i].to; + if (!gc->HasEdge(a, b)) { + PrintEdges(edges); + PrintGCEdges(nodes, gc); + LOG(FATAL) << "!gc->HasEdge(" << a << ", " << b << ")"; + } + } + for (int i = 0; i != nodes->size(); i++) { + for (int j = 0; j != nodes->size(); j++) { + int a = (*nodes)[i]; + int b = (*nodes)[j]; + if (gc->HasEdge(a, b)) { + count++; + } + } + } + if (count != edges->size()) { + PrintEdges(edges); + PrintGCEdges(nodes, gc); + LOG(FATAL) << "edges->size() " << edges->size() << " count " << count; + } +} + +// Returns the index of a randomly chosen node in *nodes. +// Requires *nodes be non-empty. 
+static int RandomNode(std::mt19937 *rnd, Nodes *nodes) { + std::uniform_int_distribution distribution(0, nodes->size() - 1); + return distribution(*rnd); +} + +// Returns the index of a randomly chosen edge in *edges. +// Requires *edges be non-empty. +static int RandomEdge(std::mt19937 *rnd, Edges *edges) { + std::uniform_int_distribution distribution(0, edges->size() - 1); + return distribution(*rnd); +} + +// Returns the index of edge (from, to) in *edges or -1 if it is not in *edges. +static int EdgeIndex(Edges *edges, int from, int to) { + int i = 0; + while (i != edges->size() && + ((*edges)[i].from != from || (*edges)[i].to != to)) { + i++; + } + return i == edges->size() ? -1 : i; +} + +TEST(GraphCycles, RandomizedTest) { + Nodes nodes; + Edges edges; // from, to + tensorflow::GraphCycles graph_cycles; + static const int kMaxNodes = 7; // use <= 7 nodes to keep test short + static const int kDataOffset = 17; // an offset to the node-specific data + int n = 100000; + int op = 0; + std::mt19937 rnd(tensorflow::testing::RandomSeed() + 1); + + for (int iter = 0; iter != n; iter++) { + if ((iter % 10000) == 0) VLOG(0) << "Iter " << iter << " of " << n; + + if (VLOG_IS_ON(3)) { + LOG(INFO) << "==============="; + LOG(INFO) << "last op " << op; + PrintNodes(&nodes); + PrintEdges(&edges); + PrintGCEdges(&nodes, &graph_cycles); + } + for (int i = 0; i != nodes.size(); i++) { + ASSERT_EQ(reinterpret_cast(graph_cycles.GetNodeData(i)), + i + kDataOffset) + << " node " << i; + } + CheckEdges(&nodes, &edges, &graph_cycles); + CheckTransitiveClosure(&nodes, &edges, &graph_cycles); + std::uniform_int_distribution distribution(0, 5); + op = distribution(rnd); + switch (op) { + case 0: // Add a node + if (nodes.size() < kMaxNodes) { + int new_node = graph_cycles.NewNode(); + ASSERT_NE(-1, new_node); + VLOG(1) << "adding node " << new_node; + ASSERT_EQ(0, graph_cycles.GetNodeData(new_node)); + graph_cycles.SetNodeData( + new_node, reinterpret_cast( + static_cast(new_node + kDataOffset))); + ASSERT_GE(new_node, 0); + for (int i = 0; i != nodes.size(); i++) { + ASSERT_NE(nodes[i], new_node); + } + nodes.push_back(new_node); + } + break; + + case 1: // Remove a node + if (nodes.size() > 0) { + int node_index = RandomNode(&rnd, &nodes); + int node = nodes[node_index]; + nodes[node_index] = nodes.back(); + nodes.pop_back(); + VLOG(1) << "removing node " << node; + graph_cycles.RemoveNode(node); + int i = 0; + while (i != edges.size()) { + if (edges[i].from == node || edges[i].to == node) { + edges[i] = edges.back(); + edges.pop_back(); + } else { + i++; + } + } + } + break; + + case 2: // Add an edge + if (nodes.size() > 0) { + int from = RandomNode(&rnd, &nodes); + int to = RandomNode(&rnd, &nodes); + if (EdgeIndex(&edges, nodes[from], nodes[to]) == -1) { + if (graph_cycles.InsertEdge(nodes[from], nodes[to])) { + Edge new_edge; + new_edge.from = nodes[from]; + new_edge.to = nodes[to]; + edges.push_back(new_edge); + } else { + std::unordered_set seen; + ASSERT_TRUE(IsReachable(&edges, nodes[to], nodes[from], &seen)) + << "Edge " << nodes[to] << "->" << nodes[from]; + } + } + } + break; + + case 3: // Remove an edge + if (edges.size() > 0) { + int i = RandomEdge(&rnd, &edges); + int from = edges[i].from; + int to = edges[i].to; + ASSERT_EQ(i, EdgeIndex(&edges, from, to)); + edges[i] = edges.back(); + edges.pop_back(); + ASSERT_EQ(-1, EdgeIndex(&edges, from, to)); + VLOG(1) << "removing edge " << from << " " << to; + graph_cycles.RemoveEdge(from, to); + } + break; + + case 4: // Check a path + if 
(nodes.size() > 0) { + int from = RandomNode(&rnd, &nodes); + int to = RandomNode(&rnd, &nodes); + int32 path[2 * kMaxNodes]; + int path_len = graph_cycles.FindPath(nodes[from], nodes[to], + 2 * kMaxNodes, path); + std::unordered_set seen; + bool reachable = IsReachable(&edges, nodes[from], nodes[to], &seen); + bool gc_reachable = graph_cycles.IsReachable(nodes[from], nodes[to]); + ASSERT_EQ(gc_reachable, + graph_cycles.IsReachableNonConst(nodes[from], nodes[to])); + ASSERT_EQ(path_len != 0, reachable); + ASSERT_EQ(path_len != 0, gc_reachable); + // In the following line, we add one because a node can appear + // twice, if the path is from that node to itself, perhaps via + // every other node. + ASSERT_LE(path_len, kMaxNodes + 1); + if (path_len != 0) { + ASSERT_EQ(nodes[from], path[0]); + ASSERT_EQ(nodes[to], path[path_len - 1]); + for (int i = 1; i < path_len; i++) { + ASSERT_NE(-1, EdgeIndex(&edges, path[i - 1], path[i])); + ASSERT_TRUE(graph_cycles.HasEdge(path[i - 1], path[i])); + } + } + } + break; + + case 5: // Check invariants + CHECK(graph_cycles.CheckInvariants()); + break; + + default: + LOG(FATAL); + } + + // Very rarely, test graph expansion by adding then removing many nodes. + std::bernoulli_distribution rarely(1.0 / 1024.0); + if (rarely(rnd)) { + VLOG(3) << "Graph expansion"; + CheckEdges(&nodes, &edges, &graph_cycles); + CheckTransitiveClosure(&nodes, &edges, &graph_cycles); + for (int i = 0; i != 256; i++) { + int new_node = graph_cycles.NewNode(); + ASSERT_NE(-1, new_node); + VLOG(1) << "adding node " << new_node; + ASSERT_GE(new_node, 0); + ASSERT_EQ(0, graph_cycles.GetNodeData(new_node)); + graph_cycles.SetNodeData( + new_node, reinterpret_cast( + static_cast(new_node + kDataOffset))); + for (int j = 0; j != nodes.size(); j++) { + ASSERT_NE(nodes[j], new_node); + } + nodes.push_back(new_node); + } + for (int i = 0; i != 256; i++) { + ASSERT_GT(nodes.size(), 0); + int node_index = RandomNode(&rnd, &nodes); + int node = nodes[node_index]; + nodes[node_index] = nodes.back(); + nodes.pop_back(); + VLOG(1) << "removing node " << node; + graph_cycles.RemoveNode(node); + int j = 0; + while (j != edges.size()) { + if (edges[j].from == node || edges[j].to == node) { + edges[j] = edges.back(); + edges.pop_back(); + } else { + j++; + } + } + } + CHECK(graph_cycles.CheckInvariants()); + } + } +} + +class GraphCyclesTest : public ::testing::Test { + public: + tensorflow::GraphCycles g_; + + // Test relies on ith NewNode() call returning Node numbered i + GraphCyclesTest() { + for (int i = 0; i < 100; i++) { + CHECK_EQ(i, g_.NewNode()); + } + CHECK(g_.CheckInvariants()); + } + + bool AddEdge(int x, int y) { return g_.InsertEdge(x, y); } + + void AddMultiples() { + // For every node x > 0: add edge to 2*x, 3*x + for (int x = 1; x < 25; x++) { + EXPECT_TRUE(AddEdge(x, 2 * x)) << x; + EXPECT_TRUE(AddEdge(x, 3 * x)) << x; + } + CHECK(g_.CheckInvariants()); + } + + string Path(int x, int y) { + static const int kPathSize = 5; + int32 path[kPathSize]; + int np = g_.FindPath(x, y, kPathSize, path); + string result; + for (int i = 0; i < np; i++) { + if (i >= kPathSize) { + result += " ..."; + break; + } + if (!result.empty()) result.push_back(' '); + char buf[20]; + snprintf(buf, sizeof(buf), "%d", path[i]); + result += buf; + } + return result; + } +}; + +TEST_F(GraphCyclesTest, NoCycle) { + AddMultiples(); + CHECK(g_.CheckInvariants()); +} + +TEST_F(GraphCyclesTest, SimpleCycle) { + AddMultiples(); + EXPECT_FALSE(AddEdge(8, 4)); + EXPECT_EQ("4 8", Path(4, 8)); + 
CHECK(g_.CheckInvariants()); +} + +TEST_F(GraphCyclesTest, IndirectCycle) { + AddMultiples(); + EXPECT_TRUE(AddEdge(16, 9)); + CHECK(g_.CheckInvariants()); + EXPECT_FALSE(AddEdge(9, 2)); + EXPECT_EQ("2 4 8 16 9", Path(2, 9)); + CHECK(g_.CheckInvariants()); +} + +TEST_F(GraphCyclesTest, LongPath) { + ASSERT_TRUE(AddEdge(2, 4)); + ASSERT_TRUE(AddEdge(4, 6)); + ASSERT_TRUE(AddEdge(6, 8)); + ASSERT_TRUE(AddEdge(8, 10)); + ASSERT_TRUE(AddEdge(10, 12)); + ASSERT_FALSE(AddEdge(12, 2)); + EXPECT_EQ("2 4 6 8 10 ...", Path(2, 12)); + CHECK(g_.CheckInvariants()); +} + +TEST_F(GraphCyclesTest, RemoveNode) { + ASSERT_TRUE(AddEdge(1, 2)); + ASSERT_TRUE(AddEdge(2, 3)); + ASSERT_TRUE(AddEdge(3, 4)); + ASSERT_TRUE(AddEdge(4, 5)); + g_.RemoveNode(3); + ASSERT_TRUE(AddEdge(5, 1)); +} + +TEST_F(GraphCyclesTest, ManyEdges) { + const int N = 50; + for (int i = 0; i < N; i++) { + for (int j = 1; j < N; j++) { + ASSERT_TRUE(AddEdge(i, i + j)); + } + } + CHECK(g_.CheckInvariants()); + ASSERT_TRUE(AddEdge(2 * N - 1, 0)); + CHECK(g_.CheckInvariants()); + ASSERT_FALSE(AddEdge(10, 9)); + CHECK(g_.CheckInvariants()); +} + +TEST_F(GraphCyclesTest, ContractEdge) { + ASSERT_TRUE(AddEdge(1, 2)); + ASSERT_TRUE(AddEdge(1, 3)); + ASSERT_TRUE(AddEdge(2, 3)); + ASSERT_TRUE(AddEdge(2, 4)); + ASSERT_TRUE(AddEdge(3, 4)); + + EXPECT_FALSE(g_.ContractEdge(1, 3)); + CHECK(g_.CheckInvariants()); + EXPECT_TRUE(g_.HasEdge(1, 3)); + + EXPECT_TRUE(g_.ContractEdge(1, 2)); + CHECK(g_.CheckInvariants()); + EXPECT_TRUE(g_.HasEdge(1, 3)); + EXPECT_TRUE(g_.HasEdge(1, 4)); + EXPECT_TRUE(g_.HasEdge(3, 4)); + + EXPECT_TRUE(g_.ContractEdge(1, 3)); + CHECK(g_.CheckInvariants()); + EXPECT_TRUE(g_.HasEdge(1, 4)); +} + +static void BM_StressTest(int iters, int num_nodes) { + while (iters > 0) { + tensorflow::GraphCycles g; + int32 *nodes = new int32[num_nodes]; + for (int i = 0; i < num_nodes; i++) { + nodes[i] = g.NewNode(); + } + for (int i = 0; i < num_nodes && iters > 0; i++, iters--) { + int end = std::min(num_nodes, i + 5); + for (int j = i + 1; j < end; j++) { + if (nodes[i] >= 0 && nodes[j] >= 0) { + CHECK(g.InsertEdge(nodes[i], nodes[j])); + } + } + } + delete[] nodes; + } +} +BENCHMARK(BM_StressTest)->Range(2048, 1048576); diff --git a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc new file mode 100644 index 0000000000..4d49a14b24 --- /dev/null +++ b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc @@ -0,0 +1,37 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/jit/build_xla_launch_ops_pass.h" +#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" +#include "tensorflow/compiler/jit/mark_for_compilation_pass.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" + +namespace tensorflow { + +REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10, + MarkForCompilationPass); + +// The EncapsulateSubgraphs pass must run after the MarkForCompilationPass. We +// also need to run it after the graph been rewritten to have _Send nodes added +// for fetches. Before the _Send nodes are added, fetch nodes are identified by +// name, and encapsulation might remove that node from the graph. +REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 20, + EncapsulateSubgraphsPass); + +// Must run after EncapsulateSubgraphsPass. +REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 30, + BuildXlaLaunchOpsPass); + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/legacy_flags/BUILD b/tensorflow/compiler/jit/legacy_flags/BUILD new file mode 100644 index 0000000000..4491dd6ac8 --- /dev/null +++ b/tensorflow/compiler/jit/legacy_flags/BUILD @@ -0,0 +1,67 @@ +# Legacy command line flags for the XLA bridge libraries. + +# Please do not add more flags to this package. + +# The XLA bridge libraries were written in an environment that allowed +# command-line flags to be scattered freely throughout the libraries. This +# model, while initially convenient, leads to a proliferation in unused command +# line flags in tests and binaries, and serious problems in servers, where one +# might wish parameters to be different in independent RPC calls to the same +# routine. +# +# Please don't add more flags. If you're a library author, pass options and +# parameters explicitly through the library's interface. 
+ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +cc_library( + name = "encapsulate_subgraphs_pass_flags", + srcs = ["encapsulate_subgraphs_pass_flags.cc"], + hdrs = ["encapsulate_subgraphs_pass_flags.h"], + deps = + [ + "//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "mark_for_compilation_pass_flags", + srcs = ["mark_for_compilation_pass_flags.cc"], + hdrs = ["mark_for_compilation_pass_flags.h"], + deps = + [ + "//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "parallel_check_op_flags", + srcs = ["parallel_check_op_flags.cc"], + hdrs = ["parallel_check_op_flags.h"], + deps = + [ + "//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], +) + +# ----------------------------------------------------------------------------- + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.cc b/tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.cc new file mode 100644 index 0000000000..856475f12c --- /dev/null +++ b/tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.cc @@ -0,0 +1,63 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Legacy flags for the XLA bridge's encapsulate_subgraphs_pass module. + +#include +#include + +#include "tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace tensorflow { +namespace legacy_flags { + +// Pointers to the parsed value of the flags and flag descriptors, initialized +// via flags_init. +static EncapsulateSubgraphsPassFlags* flags; +static std::vector* flag_list; +static std::once_flag flags_init; + +// Allocate *flags. Called via call_once(&flags_init,...). +static void AllocateFlags() { + flags = new EncapsulateSubgraphsPassFlags; + flags->tf_xla_parallel_checking = false; + flag_list = new std::vector({ + Flag("tf_xla_parallel_checking", &flags->tf_xla_parallel_checking, + "Debug tool. Runs both JIT-compiled and interpreted graphs in " + "parallel and verifies they produce the same outputs."), + }); + xla::legacy_flags::ParseFlagsFromEnv(*flag_list); +} + +// Append to *append_to flag definitions associated with the XLA bridge's +// encapsulate_subgraphs_pass module. 
+void AppendEncapsulateSubgraphsPassFlags(std::vector* append_to) { + std::call_once(flags_init, &AllocateFlags); + append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); +} + +// Return a pointer to the EncapsulateSubgraphsPassFlags struct; +// repeated calls return the same pointer. +// This should be called only after Flags::Parse() has returned. +EncapsulateSubgraphsPassFlags* GetEncapsulateSubgraphsPassFlags() { + std::call_once(flags_init, &AllocateFlags); + return flags; +} + +} // namespace legacy_flags +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h b/tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h new file mode 100644 index 0000000000..d371bd269d --- /dev/null +++ b/tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h @@ -0,0 +1,50 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_ENCAPSULATE_SUBGRAPHS_PASS_FLAGS_H_ +#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_ENCAPSULATE_SUBGRAPHS_PASS_FLAGS_H_ + +// Legacy flags for the XLA bridge's encapsulate_subgraphs_pass module. + +#include + +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace tensorflow { +namespace legacy_flags { + +// Append to *flag_list flag definitions associated with the XLA bridge's +// encapsulate_subgraphs_pass module. +void AppendEncapsulateSubgraphsPassFlags( + std::vector* flag_list); + +// The values of flags associated with the XLA bridge's +// encapsulate_subgraphs_pass module. +typedef struct { + bool tf_xla_parallel_checking; // Debug tool. Runs both JIT-compiled and + // interpreted graphs in parallel and verifies + // they produce the same outputs. +} EncapsulateSubgraphsPassFlags; + +// Return a pointer to the EncapsulateSubgraphsPassFlags struct; +// repeated calls return the same pointer. +// This should be called only after Flags::Parse() has returned. +EncapsulateSubgraphsPassFlags* GetEncapsulateSubgraphsPassFlags(); + +} // namespace legacy_flags +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_ENCAPSULATE_SUBGRAPHS_PASS_FLAGS_H_ diff --git a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc new file mode 100644 index 0000000000..09aee39d8c --- /dev/null +++ b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc @@ -0,0 +1,76 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Legacy flags for the XLA bridge's mark_for_compilation_pass module. + +#include +#include + +#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace tensorflow { +namespace legacy_flags { + +// Pointers to the parsed value of the flags and flag descriptors, initialized +// via flags_init. +static MarkForCompilationPassFlags* flags; +static std::vector* flag_list; +static std::once_flag flags_init; + +// Allocate *flags. Called via call_once(&flags_init,...). +static void AllocateFlags() { + flags = new MarkForCompilationPassFlags; + flags->tf_xla_auto_jit = 0; + flags->tf_xla_min_cluster_size = 2; + flags->tf_xla_max_cluster_size = std::numeric_limits::max(); + flags->tf_xla_clustering_debug = false; + flag_list = new std::vector({ + Flag("tf_xla_auto_jit", &flags->tf_xla_auto_jit, + "Control compilation of operators into XLA computations on CPU and " + "GPU devices. 0 = use ConfigProto setting; -1 = off; 1 = on for " + "things very likely to be improved; 2 = on for everything. " + "Experimental."), + Flag("tf_xla_min_cluster_size", &flags->tf_xla_min_cluster_size, + "Minimum number of operators in an XLA compilation. Ignored for " + "operators placed on an XLA device or operators explicitly marked " + "for compilation."), + Flag("tf_xla_max_cluster_size", &flags->tf_xla_max_cluster_size, + "Maximum number of operators in an XLA compilation."), + Flag("tf_xla_clustering_debug", &flags->tf_xla_clustering_debug, + "Dump graphs during XLA compilation."), + }); + xla::legacy_flags::ParseFlagsFromEnv(*flag_list); +} + +// Append to *append_to flag definitions associated with the XLA bridge's +// mark_for_compilation_pass module. +void AppendMarkForCompilationPassFlags(std::vector* append_to) { + std::call_once(flags_init, &AllocateFlags); + append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); +} + +// Return a pointer to the MarkForCompilationPassFlags struct; +// repeated calls return the same pointer. +// This should be called only after Flags::Parse() has returned. +MarkForCompilationPassFlags* GetMarkForCompilationPassFlags() { + std::call_once(flags_init, &AllocateFlags); + return flags; +} + +} // namespace legacy_flags +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h new file mode 100644 index 0000000000..24f8050742 --- /dev/null +++ b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h @@ -0,0 +1,59 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_MARK_FOR_COMPILATION_PASS_FLAGS_H_ +#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_MARK_FOR_COMPILATION_PASS_FLAGS_H_ + +// Legacy flags for the XLA bridge's mark_for_compilation_pass module. + +#include + +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace tensorflow { +namespace legacy_flags { + +// Append to *flag_list flag definitions associated with the XLA bridge's +// mark_for_compilation_pass module. +void AppendMarkForCompilationPassFlags( + std::vector* flag_list); + +// The values of flags associated with the XLA bridge's +// mark_for_compilation_pass module. +typedef struct { + int32 tf_xla_auto_jit; // Control compilation of operators into XLA + // computations on CPU and GPU devices. 0 = use + // ConfigProto setting; -1 = off; 1 = on for things + // very likely to be improved; 2 = on for everything. + // Experimental. + int32 tf_xla_min_cluster_size; // Minimum number of operators in an XLA + // compilation. Ignored for operators placed + // on an XLA device or operators explicitly + // marked for compilation. + int32 tf_xla_max_cluster_size; // Maximum number of operators in an XLA + // compilation. + bool tf_xla_clustering_debug; // Dump graphs during XLA compilation. +} MarkForCompilationPassFlags; + +// Return a pointer to the MarkForCompilationPassFlags struct; +// repeated calls return the same pointer. +// This should be called only after Flags::Parse() has returned. +MarkForCompilationPassFlags* GetMarkForCompilationPassFlags(); + +} // namespace legacy_flags +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_MARK_FOR_COMPILATION_PASS_FLAGS_H_ diff --git a/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.cc b/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.cc new file mode 100644 index 0000000000..a61694b494 --- /dev/null +++ b/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.cc @@ -0,0 +1,68 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Legacy flags for the XLA bridge's parallel_check_op module. 
+ +#include +#include + +#include "tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace tensorflow { +namespace legacy_flags { + +// Pointers to the parsed value of the flags and flag descriptors, initialized +// via flags_init. +static ParallelCheckOpFlags* flags; +static std::vector* flag_list; +static std::once_flag flags_init; + +// Allocate *flags. Called via call_once(&flags_init,...). +static void AllocateFlags() { + flags = new ParallelCheckOpFlags; + flags->parallel_check_failfast = true; + flags->parallel_check_atol = "1e-5"; + flags->parallel_check_rtol = "1e-5"; + flag_list = new std::vector({ + Flag("parallel_check_failfast", &flags->parallel_check_failfast, + "Fail immediately on first parallel-check comparison error."), + Flag("parallel_check_atol", &flags->parallel_check_atol, + "Absolute error tolerance for parallel-check comparison."), + Flag("parallel_check_rtol", &flags->parallel_check_rtol, + "Relative error tolerance for parallel-check comparison."), + }); + xla::legacy_flags::ParseFlagsFromEnv(*flag_list); +} + +// Append to *append_to flag definitions associated with the XLA bridge's +// parallel_check_op module. +void AppendParallelCheckOpFlags(std::vector* append_to) { + std::call_once(flags_init, &AllocateFlags); + append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); +} + +// Return a pointer to the ParallelCheckOpFlags struct; +// repeated calls return the same pointer. +// This should be called only after Flags::Parse() has returned. +ParallelCheckOpFlags* GetParallelCheckOpFlags() { + std::call_once(flags_init, &AllocateFlags); + return flags; +} + +} // namespace legacy_flags +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h b/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h new file mode 100644 index 0000000000..156a2a2a71 --- /dev/null +++ b/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h @@ -0,0 +1,52 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_PARALLEL_CHECK_OP_FLAGS_H_ +#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_PARALLEL_CHECK_OP_FLAGS_H_ + +// Legacy flags for the XLA bridge's parallel_check_op module. + +#include + +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace tensorflow { +namespace legacy_flags { + +// Append to *flag_list flag definitions associated with the XLA bridge's +// parallel_check_op module. +void AppendParallelCheckOpFlags(std::vector* flag_list); + +// The values of flags associated with the XLA bridge's +// parallel_check_op module. 
+typedef struct { + bool parallel_check_failfast; // Fail immediately on first parallel-check + // comparison error. + string parallel_check_atol; // Absolute error tolerance for parallel-check + // comparison. + string parallel_check_rtol; // Relative error tolerance for parallel-check + // comparison. +} ParallelCheckOpFlags; + +// Return a pointer to the ParallelCheckOpFlags struct; +// repeated calls return the same pointer. +// This should be called only after Flags::Parse() has returned. +ParallelCheckOpFlags* GetParallelCheckOpFlags(); + +} // namespace legacy_flags +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_PARALLEL_CHECK_OP_FLAGS_H_ diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc new file mode 100644 index 0000000000..486725f1da --- /dev/null +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -0,0 +1,534 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/mark_for_compilation_pass.h" + +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/graphcycles/graphcycles.h" +#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" +#include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/graph_def_util.h" +#include "tensorflow/core/framework/memory_types.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/control_flow.h" +#include "tensorflow/core/public/version.h" + +namespace tensorflow { + +const char* const kXlaClusterAttr = "_XlaCluster"; + +namespace { + +bool HasXLAKernel(const NodeDef& node_def, DeviceType jit_device_type) { + // There is a SymbolicGradient kernel on the XLA_JIT device, but the gradient + // is really a kind of function call and will be handled by + // IsCompilableCall(). + if (node_def.op() == "SymbolicGradient") return false; + return FindKernelDef(jit_device_type, node_def, nullptr, nullptr).ok(); +} + +// Make sure we don't recurse infinitely on recursive functions. +const int kMaxRecursionDepth = 5; + +bool IsCompilableCall(const NodeDef& call_def, DeviceType jit_device_type, + int depth, FunctionLibraryRuntime* lib_runtime); + +// Tests whether 'while_def' is a completely compilable loop. +// Every operator in the condition and body functions must be compilable for a +// while loop to be compilable. 
+bool IsCompilableWhile(const NodeDef& while_def, DeviceType jit_device_type, + int depth, FunctionLibraryRuntime* lib_runtime) { + VLOG(2) << "Loop marking: " << while_def.op(); + + const NameAttrList* name_attr; + NodeDef call; + Status status; + status = GetNodeAttr(while_def, "cond", &name_attr); + if (!status.ok()) { + VLOG(2) << "Missing 'cond' attribute on While node."; + return false; + } + const string cond_func = name_attr->name(); + call.set_name("while_cond"); + call.set_op(cond_func); + *call.mutable_attr() = name_attr->attr(); + if (!IsCompilableCall(call, jit_device_type, depth + 1, lib_runtime)) { + VLOG(2) << "Can't compile loop condition: " << cond_func; + return false; + } + status = GetNodeAttr(while_def, "body", &name_attr); + if (!status.ok()) { + VLOG(2) << "Missing 'body' attribute on While node."; + return false; + } + const string body_func = name_attr->name(); + call.set_name("while_body"); + call.set_op(body_func); + *call.mutable_attr() = name_attr->attr(); + if (!IsCompilableCall(call, jit_device_type, depth + 1, lib_runtime)) { + VLOG(2) << "Can't compile loop body: " << body_func; + return false; + } + VLOG(2) << "Loop is compilable."; + return true; +} + +// Tests whether 'call_def' is a call to a completely compilable function. +// Every operator in the function must be compilable for a function to be +// compilable. +bool IsCompilableCall(const NodeDef& call_def, DeviceType jit_device_type, + int depth, FunctionLibraryRuntime* lib_runtime) { + VLOG(2) << "Function marking: " << call_def.op(); + + if (depth > kMaxRecursionDepth) { + VLOG(2) << "Function depth limit exceeded"; + return false; + } + + FunctionLibraryRuntime::Handle handle; + Status status = + lib_runtime->Instantiate(call_def.op(), call_def.attr(), &handle); + if (!status.ok()) { + VLOG(2) << "Could not instantiate " << call_def.op() << ": " << status; + return false; + } + const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle); + CHECK(fbody); + + for (Node* node : fbody->graph->nodes()) { + if (node->IsSource() || node->IsSink()) continue; + if (node->def().op() == "_Arg" || node->def().op() == "_Retval") continue; + if (node->def().op() == "While") { + // Handle functional While loop (not in open source build). + return IsCompilableWhile(node->def(), jit_device_type, depth + 1, + lib_runtime); + } + if (!HasXLAKernel(node->def(), jit_device_type) && + !IsCompilableCall(node->def(), jit_device_type, depth + 1, + lib_runtime)) { + VLOG(2) << "Function marking failed: unsupported op " << node->name() + << ": " << node->def().ShortDebugString(); + return false; + } + } + VLOG(2) << "Function is compilable: " << call_def.op(); + return true; +} + +// Returns the DeviceType corresponding to 'device'. 
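// Illustrative sketch only -- not part of the patch. It has the same shape as
// IsCompilableCall/IsCompilableWhile above: recurse into callees, give up past
// a fixed depth (so recursive functions terminate the walk), and reject any op
// that is neither a known kernel nor itself a compilable call. All names here
// (CallGraph, IsCompilable, kMaxDepth) are hypothetical.
#include <map>
#include <set>
#include <string>
#include <vector>

namespace demo {

// op name -> ops used in its body (absent or empty for primitive ops).
using CallGraph = std::map<std::string, std::vector<std::string>>;

constexpr int kMaxDepth = 5;  // Mirrors kMaxRecursionDepth above.

bool IsCompilable(const CallGraph& graph,
                  const std::set<std::string>& has_kernel,
                  const std::string& op, int depth) {
  if (depth > kMaxDepth) return false;    // Likely a recursive function: bail.
  if (has_kernel.count(op)) return true;  // Primitive with a backend kernel.
  auto it = graph.find(op);
  if (it == graph.end()) return false;    // Unknown op with no kernel.
  for (const std::string& callee : it->second) {
    // Every operator in the body must be compilable for the call to be.
    if (!IsCompilable(graph, has_kernel, callee, depth + 1)) return false;
  }
  return true;
}

}  // namespace demo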
+Status DeviceTypeOfDevice(const string& device, DeviceType* device_type) { + DeviceNameUtils::ParsedName parsed; + if (!DeviceNameUtils::ParseFullName(device, &parsed)) { + return errors::Internal("Malformed assigned device '", device, "'"); + } + *device_type = DeviceType(parsed.type); + return Status::OK(); +} + +Status FindCompilationCandidates( + const Graph& graph, FunctionLibraryDefinition* flib_def, Env* env, + const std::function& is_compilable_fn, + std::unordered_set* candidates) { + OptimizerOptions opts; + std::unique_ptr lib_runtime(NewFunctionLibraryRuntime( + nullptr, env, nullptr, TF_GRAPH_DEF_VERSION, flib_def, opts)); + + for (Node* node : graph.nodes()) { + if (node->IsSource() || node->IsSink()) continue; + + DeviceType device_type(""); + TF_RETURN_IF_ERROR( + DeviceTypeOfDevice(node->assigned_device_name(), &device_type)); + + if (is_compilable_fn && !is_compilable_fn(node, device_type)) continue; + + const string* jit_device_name; + CHECK(XlaOpRegistry::GetJitDevice(device_type.type(), &jit_device_name, + /*requires_jit=*/nullptr)); + DeviceType jit_device_type(*jit_device_name); + if (!HasXLAKernel(node->def(), jit_device_type) && + !IsCompilableCall(node->def(), jit_device_type, 0, lib_runtime.get())) { + VLOG(2) << "Compilation rejected node: unsupported op " << node->name() + << ": " << node->def().op(); + continue; + } + if (node->def().op() == "While" && + !IsCompilableWhile(node->def(), jit_device_type, 0, + lib_runtime.get())) { + continue; + } + candidates->insert(node); + } + return Status::OK(); +} + +// Union-Find data structure used to compute clusters. We use our own +// implementation because we want one key feature: when merging clusters, we +// need to know which value becomes the representative of the merged clusters. +// We use the representatives to name nodes in a cycle detection graph, and we +// need to control which node is named. +// TODO(phawkins): consider merging this code with union-find implementations +// in Tensorflow, e.g., in SimplePlacer. +class Cluster { + public: + Cluster(); + + int Size() { return FindRoot()->size_; } + + // Merges this cluster with 'other'. This cluster's representative becomes + // the representative of the merged cluster; the representative of 'other' + // is ignored. + void Merge(Cluster* other); + + // Each cluster has an associated integer 'representative', initialized to -1 + // by default. + int GetRepresentative() { return FindRoot()->representative_; } + void SetRepresentative(int representative) { + FindRoot()->representative_ = representative; + } + + private: + // Finds the root element of the cluster. Performs path compression. + Cluster* FindRoot(); + + int representative_; + int rank_; + int size_; // Size of the cluster. + Cluster* parent_; +}; + +Cluster::Cluster() + : representative_(-1), rank_(0), size_(1), parent_(nullptr) {} + +void Cluster::Merge(Cluster* other) { + Cluster* a = FindRoot(); + Cluster* b = other->FindRoot(); + if (a == b) return; + if (a->rank_ > b->rank_) { + b->parent_ = a; + a->size_ += b->size_; + return; + } + + a->parent_ = b; + if (a->rank_ == b->rank_) { + b->rank_++; + } + b->representative_ = a->representative_; + b->size_ += a->size_; +} + +Cluster* Cluster::FindRoot() { + if (!parent_) return this; + // Path compression: update intermediate nodes to point to the root of the + // equivalence class. 
+ parent_ = parent_->FindRoot(); + return parent_; +} + +} // anonymous namespace + +bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) { + Device* device = flr->device(); + const string* jit_device_name; + CHECK(XlaOpRegistry::GetJitDevice(device->device_type(), &jit_device_name, + /*requires_jit=*/nullptr)); + DeviceType jit_device_type(*jit_device_name); + return IsCompilableCall(ndef, jit_device_type, 0, flr); +} + +Status MarkForCompilationPass::Run( + const GraphOptimizationPassOptions& options) { + // TODO(phawkins): precompute the "GetJitDevice" properties each device ahead + // of time. + OptimizerOptions::GlobalJitLevel global_jit_level = + options.session_options->config.graph_options() + .optimizer_options() + .global_jit_level(); + if (global_jit_level == OptimizerOptions::DEFAULT) { + // To set compilation to be on by default, change the following line. + global_jit_level = OptimizerOptions::OFF; + } + legacy_flags::MarkForCompilationPassFlags* flags = + legacy_flags::GetMarkForCompilationPassFlags(); + if (flags->tf_xla_auto_jit == -1 || + (1 <= flags->tf_xla_auto_jit && flags->tf_xla_auto_jit <= 2)) { + // If the flag tf_xla_auto_jit is a valid, non-zero setting, it overrides + // the setting in ConfigProto. + global_jit_level = + static_cast(flags->tf_xla_auto_jit); + } + const FunctionLibraryDefinition* fld = options.flib_def; + auto is_compilable = [global_jit_level, fld](const Node* node, + const DeviceType& device_type) { + const string* jit_device; + bool requires_jit; + if (!XlaOpRegistry::GetJitDevice(device_type.type(), &jit_device, + &requires_jit)) { + return false; + } + // If this device requires a JIT, we must say yes. + if (requires_jit) return true; + + // If there is a _XlaCompile annotation, use its value. + bool compile = false; + Status status = GetNodeAttr(node->def(), kXlaCompileAttr, &compile); + if (status.ok()) return compile; + + status = fld->GetAttr(node->def(), kXlaCompileAttr, &compile); + if (status.ok()) return compile; + + // Otherwise use the value of global_jit_level. + return global_jit_level > 0; + }; + return RunImpl(options, is_compilable); +} + +// Is 'node' an operator that consumes only the shape of its input, not the +// data itself? +static bool IsShapeConsumerOp(const Node& node) { + return node.type_string() == "Shape" || node.type_string() == "Rank" || + node.type_string() == "Size"; +} + +// Sequence number generator to ensure clusters have unique names. +static std::atomic cluster_sequence_num; + +Status MarkForCompilationPass::RunImpl( + const GraphOptimizationPassOptions& options, + const std::function& + is_compilable_fn) { + VLOG(1) << "MarkForCompilationPass::Run"; + + // Make sure that kernels have been registered on the JIT device. + XlaOpRegistry::RegisterJitKernels(); + + Graph* graph = options.graph->get(); + + std::unordered_set compilation_candidates; + TF_RETURN_IF_ERROR(FindCompilationCandidates( + *graph, options.flib_def, + (options.session_options != nullptr) ? options.session_options->env + : Env::Default(), + is_compilable_fn, &compilation_candidates)); + + GraphCycles cycles; + for (int i = 0; i < graph->num_node_ids(); ++i) { + // We rely on the node IDs in the cycle detection graph being consecutive + // integers starting from 0. + CHECK_EQ(i, cycles.NewNode()); + } + + // Compute the loop structure of the graph. 
+ std::vector control_flow_info; + TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &control_flow_info)); + + // The clustering code must avoid adding cycles to the graph to prevent + // deadlock. However, the graph may contain loops, which would trigger the + // cycle detection code. To handle loops, we alter the structure of the cycle + // detection graph, disconnecting each loop from the enclosing graph. + // Specifically, we: + // * add a new "frame" node for each loop. + // * replace edges to "Enter" nodes, and edges from "Exit" nodes with edges + // to/from the corresponding frame node. In essence, we collapse the loop + // into a single node for the purpose of cycle detection in the enclosing + // graph. + // * the body of the loop should now be disconnected from the rest of the + // graph; we make it acyclic by breaking loop backedges (edges outgoing from + // "NextIteration" nodes. + + // Map from frame name strings to node IDs in the cycle detection graph. + std::unordered_map frame_nodes; + + // Get the cycle graph node ID for frame 'frame_name', or add one if none + // exists. + auto GetOrAddFrameNodeId = [&frame_nodes, &cycles](const string& frame_name) { + int& frame_id = frame_nodes.emplace(frame_name, -1).first->second; + if (frame_id < 0) { + // The emplace succeeded; we have not allocated a frame node yet. + frame_id = cycles.NewNode(); + } + return frame_id; + }; + + for (Edge const* edge : graph->edges()) { + if (edge->dst()->IsEnter()) { + // Lift edges to an "Enter" node to the corresponding frame node. + const string& frame_name = + control_flow_info[edge->dst()->id()].frame_name; + if (!cycles.InsertEdge(edge->src()->id(), + GetOrAddFrameNodeId(frame_name))) { + return errors::Internal("Cycle detected when adding enter->frame edge"); + } + continue; + } + if (edge->src()->IsExit()) { + // Lift edges from an "Exit" node to the corresponding frame node. + const string& frame_name = + control_flow_info[edge->src()->id()].frame_name; + if (!cycles.InsertEdge(GetOrAddFrameNodeId(frame_name), + edge->dst()->id())) { + return errors::Internal("Cycle detected when adding frame->exit edge"); + } + // Drop the original edge. + continue; + } + if (edge->src()->IsNextIteration()) { + // Break loop back-edges. + continue; + } + if (!cycles.InsertEdge(edge->src()->id(), edge->dst()->id())) { + // This should never happen. All cycles in the graph should contain + // a control flow operator. + return errors::Internal( + "Found cycle in graph without control flow operator during XLA " + "compilation."); + } + } + + // Each compilation candidate belongs to a cluster. The cluster's + // representative + // names the node in the 'cycles' graph that represents the cluster. + std::vector clusters(graph->num_node_ids()); + std::deque worklist; + for (Node* node : compilation_candidates) { + clusters[node->id()].SetRepresentative(node->id()); + worklist.push_back(&clusters[node->id()]); + } + + legacy_flags::MarkForCompilationPassFlags* flags = + legacy_flags::GetMarkForCompilationPassFlags(); + + // Repeatedly contract edges between clusters that are on the same device, + // provided the contraction would not create a cycle. + while (!worklist.empty()) { + int from = worklist.front()->GetRepresentative(); + worklist.pop_front(); + + Node* node_from = graph->FindNodeId(from); + if (node_from->IsControlFlow()) { + // Control flow nodes aren't compilation candidates and should never + // appear. 
+ return errors::Internal("Found control flow node in clustering worklist"); + } + for (int to : cycles.Successors(from)) { + if (to >= graph->num_node_ids()) { + // Node is a "frame" node that is present only in the cycle detection + // graph. No clustering is possible. + continue; + } + Node* node_to = graph->FindNodeId(to); + if (compilation_candidates.find(node_to) == compilation_candidates.cend()) + continue; + if (node_from->assigned_device_name() != node_to->assigned_device_name()) + continue; + + // Ops that consume shapes cannot be the root of a cluster. This is an + // optimization. + if (clusters[from].Size() == 1 && IsShapeConsumerOp(*node_from)) { + continue; + } + + // Don't exceed the maximum cluster size. + if (clusters[from].Size() + clusters[to].Size() > + flags->tf_xla_max_cluster_size) { + continue; + } + + // If contracting the edge would create a cycle, bail out. + // However, just because we can't merge the clusters now does not mean + // we won't be able to merge them in the future. + // e.g., if we have edges 1->2, 2->3 and 1->3, we cannot contract edge + // 1->3. But if we first contract 1->2 then we can later contract 1->3. + if (!cycles.ContractEdge(from, to)) continue; + + // Merge the clusters. ContractEdge uses 'from' as the number of the + // merged node, so make sure 'from' is the chosen representative. + clusters[from].Merge(&clusters[to]); + + worklist.push_back(&clusters[from]); + break; + } + } + + // Count the number of elements in each cluster. + std::vector cluster_sizes(graph->num_node_ids()); + for (const Node* n : compilation_candidates) { + int cluster = clusters[n->id()].GetRepresentative(); + cluster_sizes[cluster]++; + } + + // Names for each cluster. + std::unordered_map cluster_names; + + // Mark clusters for compilation that: + // * are placed on a device that requires compilation (an XlaDevice), + // * are explicitly marked for compilation (_XlaCompile=true), or + // * have more than flags->tf_xla_min_cluster_size elements (applicable only + // if compilation is enabled, otherwise there will be no such candidates). + const int min_cluster_size = flags->tf_xla_min_cluster_size; + for (Node* n : compilation_candidates) { + int cluster = clusters[n->id()].GetRepresentative(); + + // Compile if the user marked this node _XlaCompile=true + bool compile_attr = false; + bool marked_for_compilation = false; + if (GetNodeAttr(n->def(), kXlaCompileAttr, &compile_attr).ok()) { + marked_for_compilation = compile_attr; + } else if (options.flib_def + ->GetAttr(n->def(), kXlaCompileAttr, &compile_attr) + .ok()) { + marked_for_compilation = compile_attr; + } + + // Compile if this operator is placed on a device that requires + // compilation. + bool requires_jit = false; + DeviceType device_type(""); + TF_RETURN_IF_ERROR( + DeviceTypeOfDevice(n->assigned_device_name(), &device_type)); + XlaOpRegistry::GetJitDevice(device_type.type(), + /*jit_device_name=*/nullptr, &requires_jit); + + // Or compile if this is a cluster of >= min_cluster_size compilable + // operators. 
+ if (cluster_sizes[cluster] >= min_cluster_size || marked_for_compilation || + requires_jit) { + string& name = cluster_names[cluster]; + if (name.empty()) { + name = strings::StrCat("cluster_", cluster_sequence_num++); + } + n->AddAttr(kXlaClusterAttr, name); + } + } + + if (flags->tf_xla_clustering_debug) { + dump_graph::DumpGraphToFile("mark_for_compilation", **options.graph, + options.flib_def); + } + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.h b/tensorflow/compiler/jit/mark_for_compilation_pass.h new file mode 100644 index 0000000000..f91695800f --- /dev/null +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.h @@ -0,0 +1,55 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// An optimization passes that marks nodes that are to be compiled with +// attribute kXlaClusterAttr. Nodes with the same cluster ID will be compiled +// together. + +#ifndef TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_H_ +#define TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_H_ + +#include "tensorflow/core/common_runtime/optimization_registry.h" + +namespace tensorflow { + +// The attribute that marks nodes to be grouped into functions by the +// encapsulate subgraphs pass. +extern const char* const kXlaClusterAttr; + +// Pass that marks a subset of operators in the graph with attribute +// _XlaCluster so they are compiled by the EncapsulateSubgraphsPass. +class MarkForCompilationPass : public GraphOptimizationPass { + public: + MarkForCompilationPass() = default; + + Status Run(const GraphOptimizationPassOptions& options) override; + + // Run() just calls RunImpl() if --tf_xla_auto_jit is enabled. To run the pass + // unconditionally, call RunImpl() directly. + // is_compilable_fn, if set, is a predicate that must be true for a node to + // be compiled. + Status RunImpl(const GraphOptimizationPassOptions& options, + const std::function& + is_compilable_fn = {}); +}; + +// Returns true iff 'ndef' is a call to a function that is compilable. A +// function is compilable iff every operator in the function body is +// compilable. +bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_H_ diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc new file mode 100644 index 0000000000..560695e87d --- /dev/null +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -0,0 +1,357 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/mark_for_compilation_pass.h" + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +REGISTER_OP("UncompilableNullary").Output("o: float"); +REGISTER_OP("UncompilableUnary").Input("a: float").Output("o: float"); + +void MarkForCompilation(std::unique_ptr* graph, + FunctionLibraryDefinition* flib_def) { + // Assign all nodes to the CPU device. + static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0"; + for (Node* n : (*graph)->nodes()) { + n->set_assigned_device_name(kCpuDevice); + } + + GraphOptimizationPassOptions opt_options; + opt_options.graph = graph; + opt_options.flib_def = flib_def; + MarkForCompilationPass pass; + CHECK(pass.RunImpl(opt_options).ok()); +} + +void MarkForCompilation(std::unique_ptr* graph) { + FunctionDefLibrary flib; + FunctionLibraryDefinition flib_def((*graph)->op_registry(), flib); + MarkForCompilation(graph, &flib_def); +} + +std::unordered_map GetClusters(const Graph& graph) { + std::unordered_map ids; + for (Node* node : graph.nodes()) { + string cluster; + if (GetNodeAttr(node->def(), kXlaClusterAttr, &cluster).ok()) { + CHECK(!cluster.empty()); + ids[node->name()] = cluster; + } + } + return ids; +} + +TEST(XlaCompilationTest, Chains) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + GraphDef graphdef; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = + ops::SourceOp("UncompilableNullary", builder.opts().WithName("A")); + Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B")); + Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C")); + Node* d = + ops::UnaryOp("UncompilableUnary", c, builder.opts().WithName("D")); + Node* e = ops::UnaryOp("Relu", d, builder.opts().WithName("E")); + ops::UnaryOp("Relu", e, builder.opts().WithName("F")); + builder.ToGraph(graph.get()); + } + + MarkForCompilation(&graph); + auto clusters = GetClusters(*graph); + EXPECT_EQ(4, clusters.size()); + EXPECT_EQ(clusters["B"], clusters["C"]); + EXPECT_EQ(clusters["E"], clusters["F"]); + EXPECT_NE(clusters["B"], clusters["E"]); + EXPECT_TRUE(clusters.find("A") == clusters.cend()); + EXPECT_TRUE(clusters.find("D") == clusters.cend()); +} + +TEST(XlaCompilationTest, UncompilableCycles) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + GraphDef graphdef; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = ops::SourceOp("Const", builder.opts() + .WithName("A") + .WithAttr("dtype", DT_FLOAT) + .WithAttr("value", Tensor())); + Node* b = + ops::UnaryOp("UncompilableUnary", a, builder.opts().WithName("B")); + ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C")); + 
builder.ToGraph(graph.get()); + } + + MarkForCompilation(&graph); + auto clusters = GetClusters(*graph); + + EXPECT_TRUE(clusters.empty()); +} + +TEST(XlaCompilationTest, CompilableCycles) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + GraphDef graphdef; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = ops::SourceOp("Const", builder.opts() + .WithName("A") + .WithAttr("dtype", DT_FLOAT) + .WithAttr("value", Tensor())); + Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B")); + ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C")); + builder.ToGraph(graph.get()); + } + + MarkForCompilation(&graph); + auto clusters = GetClusters(*graph); + + EXPECT_EQ(3, clusters.size()); + EXPECT_EQ(clusters["A"], clusters["B"]); + EXPECT_EQ(clusters["A"], clusters["C"]); +} + +TEST(XlaCompilationTest, UnsupportedTypes) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + GraphDef graphdef; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = ops::SourceOp( + "Const", builder.opts() + .WithName("A") + .WithAttr("dtype", DT_COMPLEX64) + .WithAttr("value", Tensor(DT_COMPLEX64, TensorShape()))); + Node* b = ops::UnaryOp("Neg", a, builder.opts().WithName("B")); + ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C")); + builder.ToGraph(graph.get()); + } + + MarkForCompilation(&graph); + auto clusters = GetClusters(*graph); + EXPECT_TRUE(clusters.empty()); +} + +TEST(XlaCompilationTest, ConcatWithConstArg) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + GraphDef graphdef; + { + Tensor t(DT_INT32, TensorShape()); + t.scalar()() = 0; + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* dim = ops::SourceOp("Const", builder.opts() + .WithName("Dim") + .WithAttr("dtype", DT_INT32) + .WithAttr("value", t)); + Node* a = ops::SourceOp("Const", builder.opts() + .WithName("A") + .WithAttr("dtype", DT_FLOAT) + .WithAttr("value", t)); + + NodeBuilder concat_builder("Concat", "Concat", + builder.opts().op_registry()); + concat_builder.Input(dim).Input({a, a}).Attr("N", 2); + builder.opts().FinalizeBuilder(&concat_builder); + + builder.ToGraph(graph.get()); + } + + MarkForCompilation(&graph); + auto clusters = GetClusters(*graph); + EXPECT_EQ(3, clusters.size()); // Everything should be compiled. 
+} + +TEST(XlaCompilationTest, FunctionCalls) { + FunctionDefLibrary flib; + *flib.add_function() = FunctionDefHelper::Define( + "CompilableFn", {"n_a:float", "n_b:float"}, {"n_c:float"}, {}, + {{{"n_c"}, "Add", {"n_a", "n_b"}, {{"T", DT_FLOAT}}}}); + *flib.add_function() = + FunctionDefHelper::Define("UncompilableFn", {"n_a:float"}, {"n_c:float"}, + {}, {{{"n_c"}, "UncompilableUnary", {"n_a"}}}); + FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib); + + std::unique_ptr graph(new Graph(&flib_def)); + GraphDef graphdef; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately, &flib_def); + Node* a = + ops::SourceOp("UncompilableNullary", builder.opts().WithName("A")); + Node* b = ops::BinaryOp("CompilableFn", a, a, builder.opts().WithName("B")); + Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C")); + ops::UnaryOp("UncompilableFn", c, builder.opts().WithName("D")); + builder.ToGraph(graph.get()); + } + + MarkForCompilation(&graph, &flib_def); + auto clusters = GetClusters(*graph); + + EXPECT_EQ(2, clusters.size()); + EXPECT_FALSE(clusters["B"].empty()); + EXPECT_EQ(clusters["B"], clusters["C"]); + EXPECT_TRUE(clusters.find("A") == clusters.cend()); + EXPECT_TRUE(clusters.find("D") == clusters.cend()); +} + +// Metadata-only operators such as Shape/Rank/Size may not be the root of a +// cluster. This is partially to work around b/26800664, and partially because +// we should probably prefer to compile metadata operators with their producers +// wherever possible, rather than their consumers. +TEST(XlaCompilationTest, MetadataOpsDontStartClusters) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + GraphDef graphdef; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = + ops::SourceOp("UncompilableNullary", builder.opts().WithName("A")); + // While all of the following ops are notionally compilable, none is + // permitted + // to start a cluster. So nothing should be compiled. + Node* b = ops::UnaryOp("Shape", a, builder.opts().WithName("B")); + Node* c = ops::UnaryOp("Rank", b, builder.opts().WithName("C")); + Node* d = ops::UnaryOp("Size", c, builder.opts().WithName("D")); + ops::UnaryOp("Shape", d, builder.opts().WithName("C")); + builder.ToGraph(graph.get()); + } + MarkForCompilation(&graph); + auto clusters = GetClusters(*graph); + EXPECT_EQ(0, clusters.size()); // Nothing should be compiled. +} + +static Status GradForUnaryCwise(FunctionDef* g, + std::vector nodes) { + for (auto& n : nodes) { + if (n.attr.empty()) { + n.attr = {{"T", DT_FLOAT}}; + } + } + *g = FunctionDefHelper::Define( + // Arg defs + {"x: float", "dy: float"}, + // Ret val defs + {"dx: float"}, + // Attr defs + {}, + // Nodes + nodes); + return Status::OK(); +} + +// A gradient containing only supported operators +Status SupportedGrad(const AttrSlice& attrs, FunctionDef* g) { + // clang-format off + return GradForUnaryCwise(g, { + {{"y"}, "Tanh", {"x"}}, + {{"y2"}, "Square", {"y"}, {}, {"dy"}}, + FunctionDefHelper::Const("one", 1.0f), + {{"a"}, "Sub", {"one", "y2"}}, + {{"dx"}, "Mul", {"dy", "a"}}, + }); + // clang-format on +} +REGISTER_OP_GRADIENT("Supported", SupportedGrad); + +// A gradient containing an unsupported operator. 
+Status UnsupportedGrad(const AttrSlice& attrs, FunctionDef* g) { + // clang-format off + return GradForUnaryCwise(g, { + {{"y"}, "Tanh", {"x"}}, + {{"y2"}, "UncompilableUnary", {"y"}, {}, {"dy"}}, + FunctionDefHelper::Const("one", 1.0f), + {{"a"}, "Sub", {"one", "y2"}}, + {{"dx"}, "Mul", {"dy", "a"}}, + }); + // clang-format on +} +REGISTER_OP_GRADIENT("Unsupported", UnsupportedGrad); + +TEST(XlaCompilationTest, SymbolicGradients) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + GraphDef graphdef; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = + ops::SourceOp("UncompilableNullary", builder.opts().WithName("A")); + + // Builds a Symbolic gradient for Supported + NodeBuilder b_builder("B", "SymbolicGradient", + builder.opts().op_registry()); + NameAttrList b_name_attr; + b_name_attr.set_name("Supported"); + b_builder.Attr("f", b_name_attr); + b_builder.Attr("Tin", {DT_FLOAT, DT_FLOAT}); + b_builder.Attr("Tout", {DT_FLOAT}); + b_builder.Input({a, a}); + Node* b = builder.opts().FinalizeBuilder(&b_builder); + + Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C")); + + // Builds a Symbolic gradient for Unsupported + NodeBuilder d_builder("D", "SymbolicGradient", + builder.opts().op_registry()); + NameAttrList d_name_attr; + d_name_attr.set_name("Unsupported"); + d_builder.Attr("f", d_name_attr); + d_builder.Attr("Tin", {DT_FLOAT, DT_FLOAT}); + d_builder.Attr("Tout", {DT_FLOAT}); + d_builder.Input({c, c}); + builder.opts().FinalizeBuilder(&d_builder); + + builder.ToGraph(graph.get()); + } + + MarkForCompilation(&graph); + auto clusters = GetClusters(*graph); + + EXPECT_EQ(2, clusters.size()); + EXPECT_FALSE(clusters["B"].empty()); + EXPECT_EQ(clusters["B"], clusters["C"]); + EXPECT_TRUE(clusters.find("A") == clusters.cend()); + EXPECT_TRUE(clusters.find("D") == clusters.cend()); +} + +TEST(XlaCompilationTest, Loops) { + // Regression test for b/32350199, where the autoclustering code introduced a + // deadlock in a graph containing a while loop. + Scope root = Scope::NewRootScope().ExitOnError(); + auto a = ops::Placeholder(root.WithOpName("A"), DT_FLOAT); + auto b = ops::Placeholder(root.WithOpName("B"), DT_FLOAT); + auto c = ops::Add(root.WithOpName("C"), a, b); + auto enter = ops::Enter(root, c, "aframe"); + auto next_iter = ops::NextIteration(root, enter); + auto exit = ops::Exit(root, next_iter); + auto d = ops::Add(root.WithOpName("D"), c, exit); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + root.ToGraph(graph.get()); + + MarkForCompilation(&graph); + auto clusters = GetClusters(*graph); + + // Nothing should be compiled. In particular, 'd' and 'c' must not be + // compiled. + EXPECT_EQ(0, clusters.size()); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/parallel_check_op.cc b/tensorflow/compiler/jit/parallel_check_op.cc new file mode 100644 index 0000000000..d07da46ca0 --- /dev/null +++ b/tensorflow/compiler/jit/parallel_check_op.cc @@ -0,0 +1,154 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h" +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { +namespace { + +REGISTER_OP("ParallelCheck") + .Attr("T: list(type) >= 0") + .Input("expected: T") + .Input("actual: T") + .Output("result: T") + .Doc(R"doc( +Op that compares two sets of inputs for near-identity, and propagates the first. +Inequality is logged to ERROR log. +)doc"); + +// Inputs 2*N tensors, outputs the first N inputs. +// Logs errors if input tensor i and i + N are not (near) identical +// in any position. +class ParallelCheckOp : public OpKernel { + public: + explicit ParallelCheckOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + template + int CompareTensors(DataType dtype, const char* v0, const char* v1, + int64 num_elts, int input_idx) { + int failed = 0; + const T* p0 = reinterpret_cast(v0); + const T* p1 = reinterpret_cast(v1); + double rtol; + legacy_flags::ParallelCheckOpFlags* flags = + legacy_flags::GetParallelCheckOpFlags(); + if (!tensorflow::strings::safe_strtod(flags->parallel_check_rtol.c_str(), + &rtol)) { + LOG(ERROR) << "can't convert parallel_check_rtol " + << flags->parallel_check_rtol << " to double"; + } + double atol; + if (!tensorflow::strings::safe_strtod(flags->parallel_check_atol.c_str(), + &atol)) { + LOG(ERROR) << "can't convert parallel_check_atol " + << flags->parallel_check_atol << " to double"; + } + for (int i = 0; i < num_elts; ++i) { + bool ok = (p0[i] == p1[i]); + VLOG(2) << "output " << input_idx << " element " << i << ": " << p0[i]; + if (!ok) { + if (std::is_same::value || std::is_same::value) { + float tolerance = + std::max(atol, std::max(fabs(rtol * p0[i]), fabs(rtol * p1[i]))); + T diff = p0[i] - p1[i]; + if (diff < 0) diff = 0 - diff; + ok = (diff <= tolerance); + } + if (ok) continue; + LOG(ERROR) << "Op " << def().name() << " fails equality at output " + << input_idx << " type " << DataTypeString(dtype) + << " element " << i << ": std_val=" << p0[i] + << " test_val=" << p1[i] << " diff=" << (p0[i] - p1[i]); + if (++failed > 10) break; + } + } + return failed; + } + + void Compute(OpKernelContext* ctx) override { + VLOG(1) << "Compute " << def().name(); + const int num_pairs = ctx->num_inputs() / 2; + for (int i = 0; i < num_pairs; ++i) { + CHECK_EQ(ctx->input_dtype(i), ctx->input_dtype(i + num_pairs)); + Tensor t0 = ctx->input(i); + Tensor t1 = ctx->input(i + num_pairs); + int64 num_elts = t0.NumElements(); + CHECK_EQ(num_elts, t1.NumElements()); + + // Compare inputs elementwise for near-exact equality. 
+ const char* v0 = t0.tensor_data().data(); + const char* v1 = t1.tensor_data().data(); + int failed = 0; + switch (ctx->input_dtype(i)) { + case DT_INT32: + failed = + CompareTensors(ctx->input_dtype(i), v0, v1, num_elts, i); + break; + case DT_INT64: + failed = + CompareTensors(ctx->input_dtype(i), v0, v1, num_elts, i); + break; + case DT_FLOAT: + failed = + CompareTensors(ctx->input_dtype(i), v0, v1, num_elts, i); + break; + case DT_DOUBLE: + failed = + CompareTensors(ctx->input_dtype(i), v0, v1, num_elts, i); + break; + case DT_BOOL: + failed = + CompareTensors(ctx->input_dtype(i), v0, v1, num_elts, i); + break; + default: + LOG(FATAL) << "unimpl: " << ctx->input_dtype(i); + } + if (failed > 0) { + LOG(ERROR) << "check failed for " << def().name() << " output " << i + << " num_elts: " << num_elts; + legacy_flags::ParallelCheckOpFlags* flags = + legacy_flags::GetParallelCheckOpFlags(); + if (flags->parallel_check_failfast) { + LOG(QFATAL) << "failfast on first parallel-check failure"; + } + } else { + VLOG(1) << "check passed for " << def().name() << " output " << i + << " num_elts: " << num_elts; + } + + // Propagate the std value. + if (IsRefType(ctx->input_dtype(i))) { + ctx->forward_ref_input_to_ref_output(i, i); + } else { + ctx->set_output(i, ctx->input(i)); + } + } + } + + TF_DISALLOW_COPY_AND_ASSIGN(ParallelCheckOp); +}; + +REGISTER_KERNEL_BUILDER(Name("ParallelCheck").Device(DEVICE_CPU), + ParallelCheckOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc new file mode 100644 index 0000000000..4644121173 --- /dev/null +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -0,0 +1,199 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
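// Illustrative sketch only -- not part of the patch. The tolerance test used by
// ParallelCheckOp above, isolated: two floating-point outputs are accepted as
// equal when their absolute difference is within max(atol, |rtol*a|, |rtol*b|).
// The name NearlyEqual is hypothetical.
#include <algorithm>
#include <cmath>

namespace demo {

bool NearlyEqual(double a, double b, double atol, double rtol) {
  if (a == b) return true;  // Fast path; also covers exact integer payloads.
  const double tolerance =
      std::max(atol, std::max(std::fabs(rtol * a), std::fabs(rtol * b)));
  return std::fabs(a - b) <= tolerance;
}

}  // namespace demo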
+==============================================================================*/ + +#include "tensorflow/compiler/jit/xla_compilation_cache.h" + +#include + +#include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/graph_optimizer.h" +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/version.h" + +namespace tensorflow { + +XlaCompilationCache::XlaCompilationCache(const XlaCompiler::Options& options) + : compiler_(options) {} + +XlaCompilationCache::~XlaCompilationCache() = default; + +string XlaCompilationCache::DebugString() { + return "XLA JIT compilation cache"; +} + +// Compute a string signature which encodes the shapes of the +// arguments in the supplied list. +string XlaCompilationCache::SignatureDebugString(const Signature& sig) { + string result = sig.name; + for (const auto& a : sig.arg_types) { + strings::StrAppend(&result, ",", DataTypeString(a.first), + a.second.DebugString()); + } + + for (const auto& v : sig.arg_values) { + strings::StrAppend(&result, "; ", v.first, ":", v.second.DebugString()); + } + return result; +} + +bool XlaCompilationCache::Signature::operator==(const Signature& other) const { + if (name != other.name) return false; + if (arg_types != other.arg_types) return false; + + if (arg_values.size() != other.arg_values.size()) return false; + for (int i = 0; i < arg_values.size(); ++i) { + if (arg_values[i].first != other.arg_values[i].first || + arg_values[i].second.tensor_data() != + other.arg_values[i].second.tensor_data()) { + return false; + } + } + return true; +} + +uint64 XlaCompilationCache::Signature::Hash::operator()( + const XlaCompilationCache::Signature& signature) const { + uint64 h = std::hash()(signature.name); + for (const auto& arg : signature.arg_types) { + h = Hash64Combine(h, std::hash()(static_cast(arg.first))); + h = Hash64Combine(h, std::hash()(arg.second.dims())); + for (int dim : arg.second.dim_sizes()) { + h = Hash64Combine(h, std::hash()(dim)); + } + } + for (const auto& arg : signature.arg_values) { + h = Hash64Combine(h, std::hash()(static_cast(arg.first))); + h = Hash64Combine(h, Hash64(arg.second.tensor_data().data(), + arg.second.tensor_data().size())); + } + return h; +} + +namespace { + +// Builds a XlaCompiler::Argument vector from the arguments to the _XlaLaunch +// op. The first `num_constant_args` arguments must be host-memory Tensors. 
+std::vector BuildArguments(int num_constant_args, + OpKernelContext* ctx) { + std::vector args(ctx->num_inputs()); + int parameter_num = 0; + for (int i = 0; i < ctx->num_inputs(); ++i) { + args[i].type = ctx->input(i).dtype(); + args[i].shape = ctx->input(i).shape(); + if (i < num_constant_args || ctx->input(i).NumElements() == 0) { + args[i].parameter = -1; + args[i].constant_value = ctx->input(i); + } else { + args[i].parameter = parameter_num; + ++parameter_num; + } + } + return args; +} + +} // namespace + +Status XlaCompilationCache::Compile( + const NameAttrList& function, int num_constant_args, OpKernelContext* ctx, + const XlaCompiler::CompilationResult** compilation_result, + xla::LocalExecutable** executable) { + VLOG(1) << "XlaCompilationCache::Compile " << DebugString(); + + if (VLOG_IS_ON(2)) { + std::vector argshapes; + VLOG(2) << "num_inputs = " << ctx->num_inputs() + << " num_constant_args= " << num_constant_args; + for (int i = 0; i < ctx->num_inputs(); i++) { + TensorShape shape = ctx->input(i).shape(); + VLOG(2) << i << ": dtype=" << ctx->input_dtype(i) + << " present=" << ctx->has_input(i) + << " shape=" << shape.DebugString(); + argshapes.push_back(shape.DebugString()); + } + VLOG(2) << "num_outputs = " << ctx->num_outputs(); + for (int i = 0; i < ctx->num_outputs(); i++) { + VLOG(2) << i << ": dtype=" << ctx->expected_output_dtype(i); + } + } + Signature signature; + signature.name = Canonicalize(function.name(), function.attr()); + for (int i = 0; i < ctx->num_inputs(); ++i) { + signature.arg_types.emplace_back(ctx->input_dtype(i), + ctx->input(i).shape()); + if (i < num_constant_args) { + signature.arg_values.emplace_back(i, ctx->input(i)); + } + } + + VLOG(2) << "Signature: " << SignatureDebugString(signature); + // The outer lock protects the existence of the cache entry. It does not + // protect the contents of the cache entry. + Entry* entry; + { + mutex_lock lock(mu_); + // Find or create a cache entry. + std::unique_ptr& e = cache_[signature]; + if (!e) { + e.reset(new Entry); + } + entry = e.get(); + } + + // Acquire the cache entry lock and compile, if necessary. + // TODO(phawkins): this locking will need to be restructured when we implement + // cache eviction. + mutex_lock entry_lock(entry->mu); + if (!entry->compiled) { + // Do the actual JIT compilation without holding the lock (it can take + // a long time.) 
+ std::vector args = + BuildArguments(num_constant_args, ctx); + + std::unique_ptr flr(NewFunctionLibraryRuntime( + compiler_.device_mgr(), ctx->env(), compiler_.device(), + TF_GRAPH_DEF_VERSION, + ctx->function_library()->GetFunctionLibraryDefinition(), + OptimizerOptions(), nullptr /* custom_kernel_creator */)); + + entry->compiled = true; + entry->compilation_status = compiler_.CompileFunction( + flr.get(), function, args, &entry->compilation_result); + } + *compilation_result = &entry->compilation_result; + if (entry->compilation_status.ok() && executable) { + if (entry->executable == nullptr && + !entry->compilation_result.computation.IsNull()) { + entry->compilation_status = compiler_.BuildExecutable( + entry->compilation_result, &entry->executable); + } + *executable = entry->executable.get(); + } + + Status status = entry->compilation_status; + return status; +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h new file mode 100644 index 0000000000..44d76db0fd --- /dev/null +++ b/tensorflow/compiler/jit/xla_compilation_cache.h @@ -0,0 +1,112 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_ + +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" + +namespace tensorflow { + +// The XlaCompilationCache class caches the results of the XlaCompiler class, +// which converts a Tensorflow graph into a compiled XLA compilation. +// +// Since XLA computations must have static shapes, the cache generates a new +// XLA computation for each new set of input shapes. +// +// Currently no cache eviction policy is implemented and the cache grows without +// bound. +class XlaCompilationCache : public ResourceBase { + public: + explicit XlaCompilationCache(const XlaCompiler::Options& options); + ~XlaCompilationCache() override; + + // Compiles a function into a XlaCompiler::CompilationResult that can be used + // to execute an XLA Computation. `compilation_result` must be non-null. + // If `executable` is non-null, also builds an xla::LocalExecutable and sets + // `executable to point to it. The resulting executable pointer may be null if + // the computation has no non-constant outputs. + // Compilation results are cached. 
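// Illustrative sketch only -- not part of the patch. XlaCompilationCache::Compile
// above uses two levels of locking: the cache-wide mutex is held only long
// enough to find-or-create the entry, and a per-entry mutex is held while the
// slow compilation runs, so distinct signatures never serialize on each other
// while a given signature is still compiled at most once. Reduced here to a
// generic memo cache with hypothetical names (MemoCache, GetOrCompute).
#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

namespace demo {

class MemoCache {
 public:
  int GetOrCompute(const std::string& key,
                   const std::function<int(const std::string&)>& compute) {
    Entry* entry;
    {
      std::lock_guard<std::mutex> lock(mu_);  // Protects only the map itself.
      std::unique_ptr<Entry>& e = cache_[key];
      if (!e) e.reset(new Entry);
      entry = e.get();  // Stays valid: entries live behind unique_ptr.
    }
    std::lock_guard<std::mutex> entry_lock(entry->mu);  // Protects the value.
    if (!entry->computed) {
      entry->value = compute(key);  // Slow work, done at most once per key.
      entry->computed = true;
    }
    return entry->value;
  }

 private:
  struct Entry {
    std::mutex mu;
    bool computed = false;
    int value = 0;
  };
  std::mutex mu_;
  std::unordered_map<std::string, std::unique_ptr<Entry>> cache_;
};

}  // namespace demo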
+ Status Compile(const NameAttrList& function, int num_constant_args, + OpKernelContext* ctx, + const XlaCompiler::CompilationResult** compilation_result, + xla::LocalExecutable** executable); + + xla::Client* client() const { return compiler_.client(); } + + string DebugString() override; + + private: + XlaCompiler compiler_; + std::unique_ptr function_library_runtime_; + + // Describes the types, shapes and any compile-time constant arguments + // to a kernel. + struct Signature { + string name; + + std::vector> arg_types; + + // List of (argument #, value) pairs for arguments whose values are + // part of the JIT signature, and that are therefore constants in any given + // JIT compilation. Tensors must be in host memory. + std::vector> arg_values; + + bool operator==(const Signature& other) const; + + struct Hash { + uint64 operator()(const Signature& signature) const; + }; + }; + static string SignatureDebugString(const Signature& sig); + + // The value associated with a cache entry. + struct Entry { + mutex mu; + + // Have we tried compiling this entry? + bool compiled = false; + + // Did compilation succeed? + Status compilation_status GUARDED_BY(mu); + + // Output of the XlaCompiler. + XlaCompiler::CompilationResult compilation_result GUARDED_BY(mu); + + // The XLA executable compiled from . May be null if no + // executable has been built. + std::unique_ptr executable GUARDED_BY(mu); + }; + + mutex mu_; + std::unordered_map, Signature::Hash> cache_ + GUARDED_BY(mu_); + + TF_DISALLOW_COPY_AND_ASSIGN(XlaCompilationCache); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_ diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc new file mode 100644 index 0000000000..92784a5358 --- /dev/null +++ b/tensorflow/compiler/jit/xla_cpu_device.cc @@ -0,0 +1,60 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Registers the XLA_CPU device, which is an XlaDevice instantiation that runs +// operators using XLA via the XLA "Host" (CPU) backend. 
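// Illustrative sketch only -- not part of the patch. The Signature struct above
// keys the cache on the canonicalized function name plus each argument's dtype
// and shape (and, in the real code, the constant argument values, omitted
// here), folding the fields into one 64-bit hash so the signature can key an
// unordered_map. HashCombine and ArgSignature are hypothetical names; the real
// code uses TensorFlow's Hash64Combine.
#include <cstdint>
#include <functional>
#include <string>
#include <utility>
#include <vector>

namespace demo {

// (dtype enum value, dimension sizes) for one argument.
using ArgSignature = std::pair<int, std::vector<int64_t>>;

inline uint64_t HashCombine(uint64_t seed, uint64_t value) {
  // Any reasonable mixing step works; this is the common golden-ratio combiner.
  return seed ^ (value + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2));
}

uint64_t HashSignature(const std::string& name,
                       const std::vector<ArgSignature>& args) {
  uint64_t h = std::hash<std::string>()(name);
  for (const ArgSignature& arg : args) {
    h = HashCombine(h, std::hash<int>()(arg.first));              // dtype
    h = HashCombine(h, std::hash<size_t>()(arg.second.size()));   // rank
    for (int64_t dim : arg.second) {
      h = HashCombine(h, std::hash<int64_t>()(dim));              // each dim
    }
  }
  return h;
}

}  // namespace demo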
+ +#include "tensorflow/compiler/jit/xla_device.h" +#include "tensorflow/compiler/jit/xla_device_ops.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +const char* const DEVICE_XLA_CPU = "XLA_CPU"; + +class XlaCpuDeviceFactory : public DeviceFactory { + public: + Status CreateDevices(const SessionOptions& options, const string& name_prefix, + std::vector* devices) override; +}; + +Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& options, + const string& name_prefix, + std::vector* devices) { + static XlaDeviceOpRegistrations* registrations = + RegisterXlaDeviceKernels(DEVICE_XLA_CPU, DEVICE_CPU_XLA_JIT); + (void)registrations; + + std::unique_ptr device; + TF_RETURN_IF_ERROR(XlaDevice::Create("Host", DEVICE_XLA_CPU, 0, + DEVICE_CPU_XLA_JIT, options, name_prefix, + &device)); + devices->push_back(device.release()); + return Status::OK(); +} + +REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_CPU, XlaCpuDeviceFactory); + +// Kernel registrations + +constexpr std::array kAllXlaCpuTypes = { + {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_BOOL}}; + +REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_CPU, XlaDeviceLaunchOp, kAllXlaCpuTypes); +REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_CPU, kAllXlaCpuTypes); + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc new file mode 100644 index 0000000000..a3fb5b4106 --- /dev/null +++ b/tensorflow/compiler/jit/xla_device.cc @@ -0,0 +1,219 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/jit/xla_device.h" + +#include +#include + +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/xla_compilation_cache.h" +#include "tensorflow/compiler/jit/xla_device_context.h" +#include "tensorflow/compiler/jit/xla_device_ops.h" +#include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/dma_helper.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/public/version.h" +#include "tensorflow/core/util/device_name_utils.h" +#include "tensorflow/core/util/stream_executor_util.h" + +namespace tensorflow { + +/* static */ Status XlaDevice::Create( + const string& platform_name, const string& device_name, int device_ordinal, + const string& jit_device_name, const SessionOptions& options, + const string& name_prefix, std::unique_ptr* device) { + VLOG(1) << "XlaDevice::Create " << platform_name << " " << device_name << ":" + << device_ordinal; + + // These are no-ops if they have already been done previously for + // this device_name/jit_device_name pair. 
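  // (For example, the XLA_CPU factory calls
  //    XlaDevice::Create("Host", DEVICE_XLA_CPU, 0, DEVICE_CPU_XLA_JIT,
  //                      options, name_prefix, &device);
  //  and the XLA_GPU factory does the same with the "CUDA" platform, so this
  //  function may run more than once per process.)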
+ XlaOpRegistry::RegisterJitKernels(); + XlaOpRegistry::RegisterJitDevice(device_name, jit_device_name, + /*requires_jit=*/true); + + auto platform = perftools::gputools::MultiPlatformManager::PlatformWithName( + platform_name); + if (!platform.ok()) { + return StreamExecutorUtil::ConvertStatus(platform.status()); + } + + const DeviceAttributes attrs = Device::BuildDeviceAttributes( + strings::StrCat(name_prefix, "/device:", device_name, ":", + device_ordinal), + DeviceType(device_name), Bytes(16ULL << 30), DeviceLocality(), + strings::StrCat("device: ", device_name, " device")); + + static Allocator* allocator = new XlaDeviceAllocator; + device->reset(new XlaDevice(options, attrs, device_ordinal, + DeviceType(jit_device_name), + platform.ValueOrDie(), allocator)); + return Status::OK(); +} + +XlaDevice::Metadata::Metadata(int device_ordinal, + perftools::gputools::Platform* platform, + const DeviceType& device_type) + : device_ordinal_(device_ordinal), + device_type_(device_type), + platform_(platform) {} + +int XlaDevice::Metadata::device_ordinal() const { return device_ordinal_; } + +perftools::gputools::Platform* XlaDevice::Metadata::platform() const { + return platform_; +} + +XlaDevice::Metadata::~Metadata() {} + +xla::LocalClient* XlaDevice::Metadata::client() const { + auto client = xla::ClientLibrary::GetOrCreateLocalClient(platform_); + return client.ValueOrDie(); +} + +const DeviceType& XlaDevice::Metadata::jit_device_type() const { + return device_type_; +} + +string XlaDevice::Metadata::DebugString() { return "XLA device metadata"; } + +XlaDevice::XlaDevice(const SessionOptions& options, + const DeviceAttributes& attrs, int device_ordinal, + const DeviceType& jit_device_name, + perftools::gputools::Platform* platform, + Allocator* xla_allocator) + : LocalDevice(options, attrs, xla_allocator), + device_ordinal_(device_ordinal), + jit_device_name_(jit_device_name), + xla_allocator_(xla_allocator), + platform_(platform) { + // Store the platform in the resource manager so Ops can retrieve it + // e.g., to lazily create a XlaCompilationCache object. + TF_CHECK_OK(resource_manager()->Create( + resource_manager()->default_container(), "xla_metadata", + new Metadata(device_ordinal_, platform_, jit_device_name_))); +} +XlaDevice::~XlaDevice() {} + +xla::LocalClient* XlaDevice::client() const { + // We lazily create the client because the platform commits to the + // details of the host hardware when the client is created, so we + // don't want to do it until we get a chance to hook the platform up + // to a simulator. + + // For now GetOrCreateLocalClient always returns success when passed + // a non-null platform. If that changes we may have to plumb in some + // way to pass Status back. 
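  // A Status-returning variant (hypothetical, not part of this change) could
  // forward that error instead of CHECK-failing:
  //
  //   Status GetClient(xla::LocalClient** client) const {
  //     auto client_or = xla::ClientLibrary::GetOrCreateLocalClient(platform_);
  //     if (!client_or.ok()) return client_or.status();
  //     *client = client_or.ValueOrDie();
  //     return Status::OK();
  //   }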
+ return xla::ClientLibrary::GetOrCreateLocalClient(platform_).ValueOrDie(); +} + +Allocator* XlaDevice::GetAllocator(AllocatorAttributes attr) { + if (attr.on_host()) { + return cpu_allocator(); + } else { + return xla_allocator_; + } +} + +Status XlaDevice::FillContextMap(const Graph* graph, + DeviceContextMap* device_context_map) { + VLOG(1) << "XlaDevice::FillContextMap"; + device_context_map->resize(graph->num_node_ids()); + XlaDeviceContext* ctx = new XlaDeviceContext(client()); + for (Node* n : graph->nodes()) { + VLOG(2) << n->id() << " : " << n->type_string() << " : " << n->name(); + ctx->Ref(); + (*device_context_map)[n->id()] = ctx; + } + return Status::OK(); +} + +void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) { + VLOG(1) << "XlaDevice::Compute " << op_kernel->name() << ":" + << op_kernel->type_string(); + op_kernel->Compute(context); +} + +void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, + AsyncOpKernel::DoneCallback done) { + VLOG(1) << "XlaDevice::ComputeAsync " << op_kernel->name() << ":" + << op_kernel->type_string(); + op_kernel->ComputeAsync(context, done); +} + +Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto, + const AllocatorAttributes alloc_attrs, + Tensor* tensor) { + VLOG(1) << "XlaDevice::MakeTensorFromProto"; + + Tensor parsed(tensor_proto.dtype()); + if (!parsed.FromProto(cpu_allocator(), tensor_proto)) { + return errors::InvalidArgument("Cannot parse tensor from proto: ", + tensor_proto.DebugString()); + } + + Status status; + if (alloc_attrs.on_host()) { + *tensor = parsed; + } else { + Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape()); + Notification n; + XlaTransferManager manager(client()); + manager.CopyCPUTensorToDevice(&parsed, this, ©, + [&n, &status](const Status& s) { + status = s; + n.Notify(); + }); + n.WaitForNotification(); + *tensor = copy; + } + VLOG(2) << "Allocated tensor at " << DMAHelper::base(tensor); + return status; +} + +XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device, + const char* jit_device) { + XlaDeviceOpRegistrations* registrations = new XlaDeviceOpRegistrations; + auto dummy_factory = [](OpKernelConstruction* context) -> OpKernel* { + return new XlaDeviceDummyOp(context); + }; + for (const KernelDef* jit_def : XlaOpRegistry::DeviceKernels(jit_device)) { + KernelDef* def = new KernelDef(*jit_def); + def->set_device_type(device); + registrations->op_kernel_registrars.emplace_back( + new kernel_factory::OpKernelRegistrar(def, "XlaDeviceDummyOp", + dummy_factory)); + } + return registrations; +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h new file mode 100644 index 0000000000..3de14f3061 --- /dev/null +++ b/tensorflow/compiler/jit/xla_device.h @@ -0,0 +1,120 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// The XlaDevice executes a TensorFlow graph using the XLA linear algebra +// runtime. +// +// Operators assigned to an XlaDevice are compiled into XLA computations. +// Tensors on an XlaDevice are thin wrappers around XLA GlobalDataHandles; state +// is managed by XLA. +// +// XlaDevice is instantiated separately for each XLA backend (e.g., CPU or GPU), +// under different names (e.g., XLA_CPU or XLA_GPU). + +#ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_ + +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/local_device.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace tensorflow { + +class XlaDevice : public LocalDevice { + public: + // Wrapper class to store metadata about the XlaDevice in the + // resource manager, where it can be looked up e.g., when lazily + // creating the XlaCompilationCache device. + class Metadata : public ResourceBase { + public: + Metadata(int device_ordinal, perftools::gputools::Platform* platform, + const DeviceType& device_type); + ~Metadata() override; + + // The index of the device on this host. + int device_ordinal() const; + + perftools::gputools::Platform* platform() const; + xla::LocalClient* client() const; + const DeviceType& jit_device_type() const; + + string DebugString() override; + + private: + const int device_ordinal_; + const DeviceType device_type_; + perftools::gputools::Platform* platform_; // Not owned. + }; + + // Factory function. 'platform_name' is the name of the XLA platform. + // 'device_name' is the name of the Tensorflow device to create. + // 'jit_device_name' is the name of the corresponding JIT device. + static Status Create(const string& platform_name, const string& device_name, + int device_ordinal, const string& jit_device_name, + const SessionOptions& options, const string& name_prefix, + std::unique_ptr* device); + + XlaDevice(const SessionOptions& options, const DeviceAttributes& attrs, + int device_ordinal, const DeviceType& jit_device_name, + ::perftools::gputools::Platform* platform, + Allocator* xla_allocator); + ~XlaDevice() override; + + Allocator* GetAllocator(AllocatorAttributes attr) override; + void Compute(OpKernel* op_kernel, OpKernelContext* context) override; + void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, + AsyncOpKernel::DoneCallback done) override; + Status Sync() override { return Status::OK(); } + + Status FillContextMap(const Graph* graph, + DeviceContextMap* device_context_map) override; + + Status MakeTensorFromProto(const TensorProto& tensor_proto, + const AllocatorAttributes alloc_attrs, + Tensor* tensor) override; + + xla::LocalClient* client() const; + + private: + // Which hardware device in the client's platform this XlaDevice controls. + const int device_ordinal_; + // The name of the device that is used to compile Ops for this XlaDevice. 
+ const DeviceType& jit_device_name_; + Allocator* xla_allocator_; // Not owned. + ::perftools::gputools::Platform* platform_; // Not owned. +}; + +// Builds dummy OpKernel registrations on 'device' for the JIT operators +// registered on 'jit_device'. Returns ownership of a XlaDeviceOpRegistrations +// object that encapsulates the kernel registrations. +struct XlaDeviceOpRegistrations { + std::vector> + op_kernel_registrars; +}; +XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device, + const char* jit_device); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_ diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc new file mode 100644 index 0000000000..250960d395 --- /dev/null +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -0,0 +1,181 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/xla_device_context.h" + +#include "tensorflow/compiler/tf2xla/literal_util.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/common_runtime/dma_helper.h" + +namespace tensorflow { + +// The contents of tensors allocated by XlaDeviceAllocator. +struct XlaGlobalData { + mutable mutex mu; + // May be nullptr if there is no xla::GlobalData backing this Tensor. + std::shared_ptr data GUARDED_BY(mu); +}; + +// The allocator used for Tensors assigned to the XLA device. The allocator +// doesn't actually back Tensors with storage. Instead, each tensor contains +// a XlaGlobalData that wraps XLA-managed storage. +XlaDeviceAllocator::XlaDeviceAllocator() = default; +XlaDeviceAllocator::~XlaDeviceAllocator() = default; + +string XlaDeviceAllocator::Name() { return "xla"; } + +void* XlaDeviceAllocator::AllocateRaw(size_t alignment, size_t num_bytes) { + // Regardless of the size requested, always allocate a XlaGlobalData. Respect + // the aligment request because there is alignment checking even for Tensors + // whose data is never accessed. + void* p = port::aligned_malloc(sizeof(XlaGlobalData), alignment); + VLOG(2) << "Allocated XLA device tensor " << p; + return new (p) XlaGlobalData(); +} + +void XlaDeviceAllocator::DeallocateRaw(void* ptr) { + XlaGlobalData* global_data = reinterpret_cast(ptr); + VLOG(2) << "Deallocated XLA device tensor " << ptr; + global_data->~XlaGlobalData(); + port::aligned_free(ptr); +} + +void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); } + +// Don't run any constructors or destructors for complex objects, +// since there is no backing store for the tensor to run them +// on. strings are the only complex objects currently stored in +// Tensors. If others are added, this set of overrides must be +// extended to include them. 
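// Note that, because AllocateRaw above always returns a single placement-new'd
// XlaGlobalData regardless of the requested size, a Tensor's data pointer can
// later be reinterpreted by the cast helpers below. Roughly (illustrative):
//
//   void* p = allocator->AllocateRaw(/*alignment=*/64, num_bytes);
//   auto* gd = static_cast<XlaGlobalData*>(p);  // wraps XLA-owned storage
//   ...
//   allocator->DeallocateRaw(p);                // runs ~XlaGlobalData()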
+void XlaDeviceAllocator::RunStringCtor(string* p, size_t n) {} +void XlaDeviceAllocator::RunStringDtor(string* p, size_t n) {} +void XlaDeviceAllocator::RunResourceCtor(ResourceHandle* p, size_t n) {} +void XlaDeviceAllocator::RunResourceDtor(ResourceHandle* p, size_t n) {} + +static const XlaGlobalData* CastTensorToXlaGlobalData(const Tensor& tensor) { + const XlaGlobalData* expression = + reinterpret_cast(tensor.tensor_data().data()); + return expression; +} + +static XlaGlobalData* CastTensorToXlaGlobalData(Tensor* tensor) { + const XlaGlobalData* expression = + reinterpret_cast(tensor->tensor_data().data()); + return const_cast(expression); +} + +XlaTransferManager::XlaTransferManager(xla::Client* client) : client_(client) {} + +void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, + Device* device, + Tensor* device_tensor, + StatusCallback done) const { + if (cpu_tensor->NumElements() > 0) { + VLOG(2) << "CopyCPUTensorToDevice " + << reinterpret_cast(cpu_tensor->tensor_data().data()) + << " " << reinterpret_cast( + device_tensor->tensor_data().data()) + << cpu_tensor->NumElements(); + xla::Literal literal; + Status status = HostTensorToLiteral(*cpu_tensor, &literal); + if (!status.ok()) { + done(status); + return; + } + auto gd = client_->TransferToServer(literal); + if (!gd.ok()) { + done(gd.status()); + return; + } + SetTensorGlobalData( + std::shared_ptr(std::move(gd.ValueOrDie())), + device_tensor); + } else { + VLOG(2) << "CopyCPUTensorToDevice empty tensor"; + } + done(Status::OK()); +} + +void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, + StringPiece tensor_name, + Device* device, + Tensor* cpu_tensor, + StatusCallback done) { + if (device_tensor->NumElements() > 0) { + VLOG(2) << "CopyDeviceTensorToCPU" + << reinterpret_cast( + device_tensor->tensor_data().data()) + << " " + << reinterpret_cast(cpu_tensor->tensor_data().data()) + << device_tensor->NumElements(); + std::shared_ptr global_data = + GetTensorGlobalData(*device_tensor); + + xla::Shape shape; + Status status = + TensorShapeToXLAShape(cpu_tensor->dtype(), cpu_tensor->shape(), &shape); + if (!status.ok()) { + done(status); + return; + } + auto result = client_->Transfer(*global_data, &shape); + if (!result.ok()) { + done(result.status()); + return; + } + const void* src_ptr = xla::LiteralUtil::InternalData(*result.ValueOrDie()); + void* dst_ptr = DMAHelper::base(cpu_tensor); + size_t total_bytes = cpu_tensor->TotalBytes(); + memcpy(dst_ptr, src_ptr, total_bytes); + } else { + VLOG(2) << "CopyDeviceTensorToCPU empty tensor"; + } + done(Status::OK()); +} + +std::shared_ptr XlaTransferManager::GetTensorGlobalData( + const Tensor& tensor) { + const XlaGlobalData* data = CastTensorToXlaGlobalData(tensor); + mutex_lock lock(data->mu); + CHECK(data->data); + return data->data; +} + +void XlaTransferManager::SetTensorGlobalData( + std::shared_ptr global_data, Tensor* tensor) { + XlaGlobalData* data = CastTensorToXlaGlobalData(tensor); + mutex_lock lock(data->mu); + data->data = std::move(global_data); +} + +XlaDeviceContext::XlaDeviceContext(xla::Client* client) : manager_(client) {} + +void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, + Device* device, + Tensor* device_tensor, + StatusCallback done) const { + manager_.CopyCPUTensorToDevice(cpu_tensor, device, device_tensor, done); +} + +void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor, + StringPiece tensor_name, + Device* device, Tensor* cpu_tensor, + StatusCallback done) { + 
manager_.CopyDeviceTensorToCPU(device_tensor, tensor_name, device, cpu_tensor, + done); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h new file mode 100644 index 0000000000..8ab462b615 --- /dev/null +++ b/tensorflow/compiler/jit/xla_device_context.h @@ -0,0 +1,92 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_CONTEXT_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_CONTEXT_H_ + +#include + +#include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// The allocator used for Tensors assigned to the XLA device. The allocator +// doesn't actually back Tensors with storage. Instead, each tensor is a thin +// wrapper around XLA-managed storage. +class XlaDeviceAllocator : public Allocator { + public: + XlaDeviceAllocator(); + ~XlaDeviceAllocator() override; + + string Name() override; + + void* AllocateRaw(size_t alignment, size_t num_bytes) override; + void DeallocateRaw(void* ptr) override; + void GetStats(AllocatorStats* stats) override; + + private: + void RunStringCtor(string* p, size_t n) override; + void RunStringDtor(string* p, size_t n) override; + void RunResourceCtor(ResourceHandle* p, size_t n) override; + void RunResourceDtor(ResourceHandle* p, size_t n) override; +}; + +// Helper class for managing data transfers between host and XLA devices. +class XlaTransferManager { + public: + explicit XlaTransferManager(xla::Client* client); + + void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, + Tensor* device_tensor, StatusCallback done) const; + void CopyDeviceTensorToCPU(const Tensor* device_tensor, + StringPiece tensor_name, Device* device, + Tensor* cpu_tensor, StatusCallback done); + + // Helper methods to get/set the xla::GlobalData backing a Tensor on the + // XlaDevice. + static std::shared_ptr GetTensorGlobalData( + const Tensor& tensor); + static void SetTensorGlobalData(std::shared_ptr global_data, + Tensor* tensor); + + private: + xla::Client* client_; +}; + +// DeviceContext for operators assigned to XlaDevice devices. The +// implementation must inherit from DeviceContext but otherwise just +// wraps the methods in XlaTransferManager. 
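// A typical synchronous use of the asynchronous copy API, mirroring
// XlaDevice::MakeTensorFromProto (a sketch; `client`, `device`, `parsed` and
// `copy` are placeholders):
//
//   XlaTransferManager manager(client);
//   Notification n;
//   Status status;
//   manager.CopyCPUTensorToDevice(&parsed, device, &copy,
//                                 [&n, &status](const Status& s) {
//                                   status = s;
//                                   n.Notify();
//                                 });
//   n.WaitForNotification();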
+class XlaDeviceContext : public DeviceContext { + public: + explicit XlaDeviceContext(xla::Client* client); + + void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, + Tensor* device_tensor, + StatusCallback done) const override; + void CopyDeviceTensorToCPU(const Tensor* device_tensor, + StringPiece tensor_name, Device* device, + Tensor* cpu_tensor, StatusCallback done) override; + + private: + XlaTransferManager manager_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_XLA_DEVICE_CONTEXT_H_ diff --git a/tensorflow/compiler/jit/xla_device_launch_op.cc b/tensorflow/compiler/jit/xla_device_launch_op.cc new file mode 100644 index 0000000000..becfbc1389 --- /dev/null +++ b/tensorflow/compiler/jit/xla_device_launch_op.cc @@ -0,0 +1,171 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/xla_device_launch_op.h" + +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/xla_compilation_cache.h" +#include "tensorflow/compiler/jit/xla_device.h" +#include "tensorflow/compiler/jit/xla_device_context.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/common_runtime/dma_helper.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/env.h" + +namespace tensorflow { + +namespace { + +Status BuildCompilationCache(ResourceMgr* rm, XlaCompilationCache** compiler) { + XlaDevice::Metadata* metadata; + Status s = rm->Lookup(rm->default_container(), + "xla_metadata", &metadata); + if (!s.ok()) { + return s; + } + core::ScopedUnref metadata_ref(metadata); + XlaCompiler::Options options; + options.device_type = metadata->jit_device_type(); + options.client = metadata->client(); + options.allow_cpu_custom_calls = false; + options.local_executable_has_hybrid_result = false; + *compiler = new XlaCompilationCache(options); + return Status::OK(); +} + +} // namespace + +XlaDeviceLaunchOp::XlaDeviceLaunchOp(OpKernelConstruction* ctx) + : OpKernel(ctx) { + const NameAttrList* func; + OP_REQUIRES_OK(ctx, ctx->GetAttr("function", &func)); + function_ = *func; + VLOG(1) << "XlaDeviceLaunch created function=" + << Canonicalize(function_.name(), function_.attr()); + DataTypeVector constant_types; + OP_REQUIRES_OK(ctx, ctx->GetAttr("Tconstants", &constant_types)); + num_constant_args_ = constant_types.size(); +} + +void XlaDeviceLaunchOp::Compute(OpKernelContext* ctx) { + VLOG(1) << "XlaDeviceLaunch::Compute " + << Canonicalize(function_.name(), function_.attr()); + // We store information about the JIT-compiled XLA computation + // in the ResourceMgr. 
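  // The cache is created lazily and stored in this device's ResourceMgr under
  // the default container with the name "xla_compiler", so every
  // XlaDeviceLaunchOp on the same device shares one XlaCompilationCache.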
+ ResourceMgr* rm = ctx->resource_manager(); + OP_REQUIRES(ctx, rm, errors::Internal("No resource manager.")); + + XlaCompilationCache* compiler; + OP_REQUIRES_OK(ctx, + rm->LookupOrCreate( + rm->default_container(), "xla_compiler", &compiler, + [rm](XlaCompilationCache** compiler) { + return BuildCompilationCache(rm, compiler); + })); + // Hold the reference to the JIT during evaluation. (We could probably + // free it sooner because the ResourceMgr will retain a reference, but + // this is more obviously correct.) + core::ScopedUnref compiler_ref(compiler); + + const XlaCompiler::CompilationResult* kernel; + OP_REQUIRES_OK( + ctx, + compiler->Compile(function_, num_constant_args_, ctx, &kernel, nullptr)); + + VLOG(1) << "Executing XLA Computation..."; + + OP_REQUIRES(ctx, ctx->num_outputs() == kernel->outputs.size(), + errors::Internal("Unexpected number of outputs")); + + // Run the computation, if any. There might not be a computation if all + // outputs were compile-time constants. + std::vector> outputs; + if (!kernel->computation.IsNull()) { + auto opaque_shape = xla::ShapeUtil::MakeOpaqueShape(); + + // Convert argument tensors to xla::GlobalData pointers. + std::vector> arg_handles( + kernel->xla_input_shapes.size()); + std::vector arg_ptrs(kernel->xla_input_shapes.size()); + for (int i = 0; i < kernel->xla_input_shapes.size(); ++i) { + int input_num = kernel->xla_input_shapes[i].first; + arg_handles[i] = + XlaTransferManager::GetTensorGlobalData(ctx->input(input_num)); + arg_ptrs[i] = arg_handles[i].get(); + } + + // Execute the computation. + xla::ExecutionProfile profile; + Env* env = Env::Default(); + auto start_time = env->NowMicros(); + auto result = compiler->client()->Execute( + kernel->computation, arg_ptrs, &kernel->xla_output_shape, &profile); + auto elapsed = env->NowMicros() - start_time; + OP_REQUIRES(ctx, result.ok(), result.status()); + + VLOG(1) << "Elapsed time: " << elapsed << "us"; + VLOG(1) << "ExecutionProfile: " << profile.DebugString(); + + if (xla::ShapeUtil::IsTuple(kernel->xla_output_shape)) { + auto outputs_or_error = + compiler->client()->DeconstructTuple(*result.ValueOrDie()); + OP_REQUIRES(ctx, outputs_or_error.ok(), outputs_or_error.status()); + outputs = outputs_or_error.ConsumeValueOrDie(); + } else { + outputs.push_back(result.ConsumeValueOrDie()); + } + } + + XlaDeviceContext* device_context = ctx->op_device_context(); + + // Copy XLA outputs to the operator's outputs. + int output_num = 0; + for (int i = 0; i < ctx->num_outputs(); ++i) { + Tensor* output; + OP_REQUIRES_OK(ctx, + ctx->allocate_output(i, kernel->outputs[i].shape, &output)); + if (kernel->outputs[i].is_constant) { + // TODO(phawkins): mark constant _XlaLaunch outputs as HostMemory and + // remove the copy from this code. 
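      // Until then, the host-side constant_value produced by the compiler has
      // to be staged into the device tensor (an XlaGlobalData wrapper) via
      // CopyCPUTensorToDevice below.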
+ Status status; + device_context->CopyCPUTensorToDevice( + &kernel->outputs[i].constant_value, nullptr, output, + [&status](const Status& s) { status = s; }); + if (!status.ok()) { + ctx->SetStatus(status); + return; + } + } else { + CHECK_LT(output_num, outputs.size()); + XlaTransferManager::SetTensorGlobalData( + std::shared_ptr(std::move(outputs[output_num])), + output); + ++output_num; + } + } + + VLOG(1) << "Done"; +} + +XlaDeviceLaunchOp::~XlaDeviceLaunchOp() { + VLOG(1) << "XlaDeviceLaunch destroyed"; +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_device_launch_op.h b/tensorflow/compiler/jit/xla_device_launch_op.h new file mode 100644 index 0000000000..fbb9319b84 --- /dev/null +++ b/tensorflow/compiler/jit/xla_device_launch_op.h @@ -0,0 +1,50 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_LAUNCH_OP_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_LAUNCH_OP_H_ + +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { + +// The XlaDeviceLaunchOp is used to replace a region of the TensorFlow graph +// which will be compiled and executed using XLA. The XlaDeviceLaunchOp is +// responsible for handling interactions with the TensorFlow executor. +// Once all inputs are present, and their shapes are known, the op can +// use a 'TlaJit' to compile and execute code which is specific +// to the shapes of input Tensors. +class XlaDeviceLaunchOp : public OpKernel { + public: + explicit XlaDeviceLaunchOp(OpKernelConstruction* ctx); + ~XlaDeviceLaunchOp() override; + + void Compute(OpKernelContext* ctx) override; + + private: + NameAttrList function_; + int num_constant_args_; + Tensor dummy_tensor_; + + TF_DISALLOW_COPY_AND_ASSIGN(XlaDeviceLaunchOp); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_XLA_DEVICE_LAUNCH_OP_H_ diff --git a/tensorflow/compiler/jit/xla_device_ops.cc b/tensorflow/compiler/jit/xla_device_ops.cc new file mode 100644 index 0000000000..74c314c8ed --- /dev/null +++ b/tensorflow/compiler/jit/xla_device_ops.cc @@ -0,0 +1,36 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/xla_device_ops.h"
+
+#include "tensorflow/compiler/jit/xla_device_context.h"
+
+namespace tensorflow {
+
+void XlaDeviceAssignOp::Copy(OpKernelContext* context, Tensor* lhs,
+                             const Tensor& rhs) {
+  std::shared_ptr<xla::GlobalData> gd =
+      XlaTransferManager::GetTensorGlobalData(rhs);
+  XlaTransferManager::SetTensorGlobalData(std::move(gd), lhs);
+}
+
+XlaDeviceDummyOp::XlaDeviceDummyOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+void XlaDeviceDummyOp::Compute(OpKernelContext* ctx) {
+  LOG(FATAL) << "Attempted to execute Op " << name() << " type " << type_string()
+             << " on an XLA device. This should never happen.";
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h
new file mode 100644
index 0000000000..1fcb515ddb
--- /dev/null
+++ b/tensorflow/compiler/jit/xla_device_ops.h
@@ -0,0 +1,118 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Common kernel registrations for XLA devices.
+
+#ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_OPS_H_
+#define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_OPS_H_
+
+#include "tensorflow/compiler/jit/xla_device_launch_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/kernels/assign_op.h"
+#include "tensorflow/core/kernels/constant_op.h"
+#include "tensorflow/core/kernels/control_flow_ops.h"
+#include "tensorflow/core/kernels/identity_op.h"
+#include "tensorflow/core/kernels/no_op.h"
+#include "tensorflow/core/kernels/sendrecv_ops.h"
+#include "tensorflow/core/kernels/variable_ops.h"
+
+namespace tensorflow {
+
+// Implementation of Assign for XLA devices.
+class XlaDeviceAssignOp : public AssignOp {
+ public:
+  using AssignOp::AssignOp;
+
+  void Copy(OpKernelContext* context, Tensor* lhs, const Tensor& rhs) override;
+};
+
+// Dummy OpKernel, used for kernels assigned to an XLA device that should be
+// compiled. Should never be called at runtime since such ops should be
+// rewritten to a _XlaLaunch op. If it is called, it means the placer placed an
+// operator on an XLA device but the compiler did not compile it.
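// RegisterXlaDeviceKernels in xla_device.cc stamps out this dummy kernel for
// every operator supported by the corresponding JIT device, e.g. (from
// xla_cpu_device.cc):
//
//   static XlaDeviceOpRegistrations* registrations =
//       RegisterXlaDeviceKernels(DEVICE_XLA_CPU, DEVICE_CPU_XLA_JIT);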
+class XlaDeviceDummyOp : public OpKernel { + public: + explicit XlaDeviceDummyOp(OpKernelConstruction* ctx); + void Compute(OpKernelContext* ctx) override; +}; + +#define REGISTER_XLA_LAUNCH_KERNEL(DEVICE, KERNEL, TYPES) \ + REGISTER_KERNEL_BUILDER( \ + Name("_XlaLaunch").Device(DEVICE).HostMemory("constants"), KERNEL); + +#define REGISTER_XLA_DEVICE_KERNELS(DEVICE, TYPES) \ + REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE), SendOp); \ + REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE), RecvOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("_HostSend").Device(DEVICE).HostMemory("tensor"), SendOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("_HostRecv").Device(DEVICE).HostMemory("tensor"), RecvOp); \ + REGISTER_KERNEL_BUILDER(Name("NoOp").Device(DEVICE), NoOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("Const").Device(DEVICE).TypeConstraint("dtype", TYPES), \ + ConstantOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("Identity").Device(DEVICE).TypeConstraint("T", TYPES), IdentityOp); \ + REGISTER_KERNEL_BUILDER(Name("Placeholder").Device(DEVICE), \ + XlaDeviceDummyOp); \ + \ + REGISTER_KERNEL_BUILDER( \ + Name("Variable").Device(DEVICE).TypeConstraint("dtype", TYPES), \ + VariableOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("VariableV2").Device(DEVICE).TypeConstraint("dtype", TYPES), \ + VariableOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("TemporaryVariable").Device(DEVICE).TypeConstraint("dtype", TYPES), \ + TemporaryVariableOp); \ + REGISTER_KERNEL_BUILDER(Name("DestroyTemporaryVariable") \ + .Device(DEVICE) \ + .TypeConstraint("T", TYPES), \ + DestroyTemporaryVariableOp); \ + REGISTER_KERNEL_BUILDER(Name("IsVariableInitialized") \ + .Device(DEVICE) \ + .TypeConstraint("dtype", TYPES) \ + .HostMemory("is_initialized"), \ + IsVariableInitializedOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("Assign").Device(DEVICE).TypeConstraint("T", TYPES), \ + XlaDeviceAssignOp); \ + \ + REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE), \ + ControlTriggerOp); \ + REGISTER_KERNEL_BUILDER(Name("Enter").Device(DEVICE), EnterOp); \ + REGISTER_KERNEL_BUILDER(Name("Exit").Device(DEVICE), ExitOp); \ + REGISTER_KERNEL_BUILDER(Name("NextIteration").Device(DEVICE), \ + NextIterationOp); \ + REGISTER_KERNEL_BUILDER(Name("Switch").Device(DEVICE).HostMemory("pred"), \ + SwitchOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("Merge").Device(DEVICE).HostMemory("value_index"), MergeOp); \ + REGISTER_KERNEL_BUILDER(Name("LoopCond") \ + .Device(DEVICE) \ + .HostMemory("input") \ + .HostMemory("output"), \ + IdentityOp); + +// TODO(phawkins): do we really need Placeholder? Should it be a real +// implementation of Placeholder? + +// TODO(b/32507444): the registrations for the control flow operators are +// temporary and exist primarily to work around a bug in the graph partitioning +// code. + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_XLA_DEVICE_OPS_H_ diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc new file mode 100644 index 0000000000..731ff7d673 --- /dev/null +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -0,0 +1,65 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Registers the XLA_GPU device, which is an XlaDevice instantiation that runs +// operators using XLA via the XLA "CUDA" (GPU) backend. + +#include "tensorflow/compiler/jit/xla_device.h" +#include "tensorflow/compiler/jit/xla_device_ops.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +const char* const DEVICE_XLA_GPU = "XLA_GPU"; + +class XlaGpuDeviceFactory : public DeviceFactory { + public: + Status CreateDevices(const SessionOptions& options, const string& name_prefix, + std::vector* devices) override; +}; + +Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& options, + const string& name_prefix, + std::vector* devices) { + static XlaDeviceOpRegistrations* registrations = + RegisterXlaDeviceKernels(DEVICE_XLA_GPU, DEVICE_GPU_XLA_JIT); + (void)registrations; + + std::unique_ptr device; + Status status = + XlaDevice::Create("CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options, + name_prefix, &device); + if (!status.ok()) { + // Treat failures as non-fatal; there might not be a GPU in the machine. + LOG(WARNING) << "Failed to create XLA_GPU device: " << status; + return Status::OK(); + } + devices->push_back(device.release()); + return Status::OK(); +} + +REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_GPU, XlaGpuDeviceFactory); + +// Kernel registrations + +constexpr std::array kAllXlaGpuTypes = { + {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_BOOL}}; + +REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaDeviceLaunchOp, kAllXlaGpuTypes); +REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_GPU, kAllXlaGpuTypes); + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_local_launch_op.cc b/tensorflow/compiler/jit/xla_local_launch_op.cc new file mode 100644 index 0000000000..7945e057cf --- /dev/null +++ b/tensorflow/compiler/jit/xla_local_launch_op.cc @@ -0,0 +1,342 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/jit/xla_local_launch_op.h" + +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_local_runtime_context.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/common_runtime/dma_helper.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/util/stream_executor_util.h" + +namespace gpu = perftools::gputools; + +namespace tensorflow { + +REGISTER_OP("_XlaLaunch") + .Input("constants: Tconstants") + .Attr("Tconstants: list(type) >= 0") + .Input("args: Targs") + .Attr("Targs: list(type) >= 0") + .Output("results: Tresults") + .Attr("Tresults: list(type) >= 0") + .Attr("function: func") + .Doc("XLA Launch Op. For use by the XLA JIT only."); + +// Adapter class that wraps a Tensorflow allocator as an XLA allocator. +class XlaAllocator : public xla::DeviceMemoryAllocator { + public: + XlaAllocator(const perftools::gputools::Platform* platform, + OpKernelContext* op_context); + ~XlaAllocator() override; + xla::StatusOr Allocate( + int device_ordinal, uint64 size, bool retry_on_failure = true) override; + Status Deallocate(int device_ordinal, + perftools::gputools::DeviceMemoryBase* mem) override; + + // Makes 'tensor' a wrapper around the data buffer at 'ptr'. The buffer is + // interpreted as having data type 'dtype' and shape 'shape'. + Status MakeTensorFromBuffer(gpu::DeviceMemoryBase buffer, DataType dtype, + const TensorShape& shape, Tensor* tensor) const; + + private: + OpKernelContext* const op_context_; + + // Map from pointer address to the owning Tensor; used by + // MakeTensorFromBuffer. Also used to automatically release Tensors when the + // allocator is freed. 
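  // For example (illustrative), after Run() has filled in an output
  // xla::ShapedBuffer, the launch op below recovers the owning Tensor with:
  //
  //   gpu::DeviceMemoryBase buffer = output->buffer({0});
  //   Tensor output_tensor;
  //   OP_REQUIRES_OK(ctx, xla_allocator.MakeTensorFromBuffer(
  //                           buffer, ctx->expected_output_dtype(0), shape,
  //                           &output_tensor));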
+ std::unordered_map tensors_; +}; + +XlaAllocator::XlaAllocator(const perftools::gputools::Platform* platform, + OpKernelContext* op_context) + : xla::DeviceMemoryAllocator(platform), op_context_(op_context) {} + +XlaAllocator::~XlaAllocator() = default; + +xla::StatusOr XlaAllocator::Allocate( + int device_ordinal, uint64 size, bool retry_on_failure) { + AllocatorAttributes allocator_attrs; + allocator_attrs.set_on_host(false); + + AllocationAttributes allocation_attrs; + allocation_attrs.no_retry_on_failure = !retry_on_failure; + + Tensor t; + Status status = op_context_->allocate_temp( + DT_UINT8, TensorShape({static_cast(size)}), &t, allocator_attrs, + allocation_attrs); + if (!status.ok()) { + VLOG(2) << "Allocation failed " << size; + return status; + } + void* data = + reinterpret_cast(const_cast(t.tensor_data().data())); + TF_RET_CHECK(data != nullptr); + tensors_[data] = t; + return perftools::gputools::DeviceMemoryBase(data, size); +} + +Status XlaAllocator::Deallocate(int device_ordinal, + perftools::gputools::DeviceMemoryBase* mem) { + if (mem->opaque() != nullptr) { + if (tensors_.erase(mem->opaque()) == 0) { + return tensorflow::errors::InvalidArgument("Unknown tensor address"); + } + } + return Status::OK(); +} + +Status XlaAllocator::MakeTensorFromBuffer(gpu::DeviceMemoryBase buffer, + DataType dtype, + const TensorShape& shape, + Tensor* out_tensor) const { + void* ptr = const_cast(buffer.opaque()); + auto it = tensors_.find(ptr); + if (it == tensors_.end()) { + return errors::InvalidArgument("Unknown tensor address"); + } + const Tensor& tensor = it->second; + + int64 output_size = DataTypeSize(dtype) * shape.num_elements(); + if (tensor.TotalBytes() == output_size) { + out_tensor->UnsafeCopyFromInternal(tensor, dtype, shape); + } else { + Tensor slice = tensor.Slice(0, output_size); + out_tensor->UnsafeCopyFromInternal(slice, dtype, shape); + } + return Status::OK(); +} + +XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx) + : OpKernel(ctx), device_type_(ctx->device_type()) { + const NameAttrList* func; + OP_REQUIRES_OK(ctx, ctx->GetAttr("function", &func)); + function_ = *func; + DataTypeVector constant_types; + OP_REQUIRES_OK(ctx, ctx->GetAttr("Tconstants", &constant_types)); + num_constant_args_ = constant_types.size(); +} + +Status XlaLocalLaunchOp::BuildCompilationCache(XlaCompilationCache** compiler) { + gpu::Platform::Id platform_id; + if (device_type_ == DeviceType(DEVICE_CPU)) { + platform_id = gpu::host::kHostPlatformId; + } else if (device_type_ == DeviceType(DEVICE_GPU)) { + platform_id = gpu::cuda::kCudaPlatformId; + } else { + return errors::InvalidArgument("Unknown device type for local _XlaLaunch"); + } + + auto platform = gpu::MultiPlatformManager::PlatformWithId(platform_id); + if (!platform.ok()) { + return StreamExecutorUtil::ConvertStatus(platform.status()); + } + auto client = + xla::ClientLibrary::GetOrCreateLocalClient(platform.ValueOrDie()); + if (!client.ok()) { + return client.status(); + } + const string* compiler_device; + if (!XlaOpRegistry::GetJitDevice(device_type_.type(), &compiler_device, + /*requires_jit=*/nullptr)) { + return errors::InvalidArgument("No JIT device registered for ", + device_type_.type()); + } + XlaCompiler::Options options; + options.device_type = DeviceType(*compiler_device); + options.client = client.ValueOrDie(); + options.allow_cpu_custom_calls = (platform_id == gpu::host::kHostPlatformId); + options.local_executable_has_hybrid_result = true; + *compiler = new XlaCompilationCache(options); + return 
Status::OK(); +} + +void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { + VLOG(1) << "XlaLocalLaunchOp::Compute " + << Canonicalize(function_.name(), function_.attr()); + // We store information about the JIT-compiled XLA computation + // in the ResourceMgr. + ResourceMgr* rm = ctx->resource_manager(); + OP_REQUIRES(ctx, rm, errors::Internal("No resource manager.")); + + gpu::Stream* stream = + ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; + + XlaCompilationCache* compiler; + OP_REQUIRES_OK(ctx, + rm->LookupOrCreate( + rm->default_container(), "xla_compiler", &compiler, + [this](XlaCompilationCache** compiler) { + return BuildCompilationCache(compiler); + })); + // Hold the reference to the JIT during evaluation. (We could probably + // free it sooner because the ResourceMgr will retain a reference, but + // this is more obviously correct.) + core::ScopedUnref compiler_ref(compiler); + + xla::LocalClient* client = static_cast(compiler->client()); + + const XlaCompiler::CompilationResult* kernel; + xla::LocalExecutable* executable; + OP_REQUIRES_OK(ctx, + compiler->Compile(function_, num_constant_args_, ctx, &kernel, + &executable)); + + VLOG(1) << "Executing XLA Computation..."; + + // Builds an XLA allocator for the device. + XlaAllocator xla_allocator(client->platform(), ctx); + XlaLocalRuntimeContext local_runtime_context; + + std::unique_ptr output; + bool output_is_tuple; + if (!kernel->computation.IsNull()) { + // Build xla::ShapedBuffers that point directly to the Tensor buffers. + std::vector> arg_buffers; + arg_buffers.reserve(kernel->xla_input_shapes.size() + 1); + arg_buffers.resize(kernel->xla_input_shapes.size()); + std::vector arg_ptrs(arg_buffers.size()); + + // Pass remaining parameters. + for (int i = 0; i < kernel->xla_input_shapes.size(); ++i) { + int arg_num = kernel->xla_input_shapes[i].first; + const xla::Shape& shape = kernel->xla_input_shapes[i].second; + gpu::DeviceMemoryBase dmem( + const_cast(ctx->input(arg_num).tensor_data().data()), + ctx->input(arg_num).tensor_data().size()); + + arg_buffers[i] = + xla::ShapedBuffer::MakeArrayShapedBuffer( + shape, client->platform(), client->default_device_ordinal(), dmem) + .ConsumeValueOrDie(); + arg_ptrs[i] = arg_buffers[i].get(); + } + + // Make the final parameter point at local_runtime_context. + if (kernel->requires_runtime_context) { + gpu::DeviceMemoryBase local_runtime_context_dmem( + &local_runtime_context, sizeof(local_runtime_context)); + arg_buffers.push_back( + xla::ShapedBuffer::MakeArrayShapedBuffer( + xla::ShapeUtil::MakeOpaqueShape(), client->platform(), + client->default_device_ordinal(), local_runtime_context_dmem) + .ConsumeValueOrDie()); + arg_ptrs.push_back(arg_buffers.back().get()); + } + + // Execute the computation. 
+ VLOG(2) << "Executing computation."; + xla::ExecutableRunOptions run_options; + run_options.set_stream(stream); + run_options.set_allocator(&xla_allocator); + run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); + Env* env = Env::Default(); + auto start_time = env->NowMicros(); + auto run_result = executable->Run(arg_ptrs, run_options); + OP_REQUIRES(ctx, run_result.ok(), run_result.status()); + + if (local_runtime_context.error) { + ctx->CtxFailure(errors::InvalidArgument( + "Compiled kernel returned error: ", local_runtime_context.error_msg)); + return; + } + + output = std::move(run_result.ValueOrDie()); + auto elapsed = env->NowMicros() - start_time; + VLOG(2) << "Elapsed time: " << elapsed << "us"; + + // Computation output should always be a tuple. + if (VLOG_IS_ON(2)) { + VLOG(2) << "Result tuple shape: " << output->shape().DebugString(); + } + output_is_tuple = xla::ShapeUtil::IsTuple(output->shape()); + } + CHECK_EQ(ctx->num_outputs(), kernel->outputs.size()); + + // Copy XLA results to the OpOutputList. + int output_num = 0; + for (int i = 0; i < ctx->num_outputs(); ++i) { + if (kernel->outputs[i].is_constant) { + // Output is a constant + const Tensor& const_tensor = kernel->outputs[i].constant_value; + const size_t total_bytes = const_tensor.TotalBytes(); + if (stream && total_bytes > 0) { + // Copy host -> device. (Empty tensors don't have backing buffers.) + VLOG(1) << "Constant output tensor on device"; + Tensor* output_tensor; + TF_CHECK_OK( + ctx->allocate_output(i, const_tensor.shape(), &output_tensor)); + + const void* src_ptr = DMAHelper::base(&const_tensor); + void* dst_ptr = DMAHelper::base(output_tensor); + gpu::DeviceMemoryBase gpu_dst_ptr(dst_ptr, total_bytes); + stream->ThenMemcpy(&gpu_dst_ptr, src_ptr, total_bytes); + } else { + // No copy required. + ctx->set_output(i, const_tensor); + } + } else { + const TensorShape& shape = kernel->outputs[i].shape; + VLOG(2) << "Retval " << i << " shape " << shape.DebugString(); + + gpu::DeviceMemoryBase buffer; + if (output_is_tuple) { + buffer = output->buffer({output_num}); + } else { + CHECK_EQ(0, output_num); + buffer = output->buffer({}); + } + Tensor output_tensor; + // Looks up the owning Tensor by buffer address. + OP_REQUIRES_OK(ctx, xla_allocator.MakeTensorFromBuffer( + buffer, ctx->expected_output_dtype(i), shape, + &output_tensor)); + ctx->set_output(i, output_tensor); + ++output_num; + } + + if (VLOG_IS_ON(3)) { + VLOG(3) << ctx->mutable_output(i)->DebugString(); + } + } + + VLOG(1) << "Done"; +} + +XlaLocalLaunchOp::~XlaLocalLaunchOp() { + VLOG(1) << "XlaLocalLaunchOp destroyed"; +} + +REGISTER_KERNEL_BUILDER(Name("_XlaLaunch").Device(DEVICE_CPU), + XlaLocalLaunchOp); + +REGISTER_KERNEL_BUILDER( + Name("_XlaLaunch").Device(DEVICE_GPU).HostMemory("constants"), + XlaLocalLaunchOp); + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_local_launch_op.h b/tensorflow/compiler/jit/xla_local_launch_op.h new file mode 100644 index 0000000000..96ae664cbe --- /dev/null +++ b/tensorflow/compiler/jit/xla_local_launch_op.h @@ -0,0 +1,55 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_XLA_LOCAL_LAUNCH_OP_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_LOCAL_LAUNCH_OP_H_ + +#include "tensorflow/compiler/jit/xla_compilation_cache.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { + +// XlaLocalLaunchOp is used to replace a region of the TensorFlow graph +// which will be compiled and executed using XLA. The XlaLocalLaunchOp is +// responsible for handling interactions with the TensorFlow executor. +// Once all inputs are present, and their shapes are known, the op can +// use a 'XlaCompilationCache' to compile and execute code which is specific +// to the shapes of input Tensors. +// XlaLocalLaunchOp uses xla::LocalClient::ExecuteLocally and passes +// arguments into/out of XLA in device memory. +class XlaLocalLaunchOp : public OpKernel { + public: + explicit XlaLocalLaunchOp(OpKernelConstruction* ctx); + ~XlaLocalLaunchOp() override; + + void Compute(OpKernelContext* ctx) override; + + private: + // Builds a XlaCompilationCache class suitable for the current device. + Status BuildCompilationCache(XlaCompilationCache** compiler); + + DeviceType device_type_; + NameAttrList function_; + int num_constant_args_; + TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalLaunchOp); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_XLA_LOCAL_LAUNCH_OP_H_ diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD new file mode 100644 index 0000000000..b4f01de4f2 --- /dev/null +++ b/tensorflow/compiler/tests/BUILD @@ -0,0 +1,352 @@ +licenses(["notice"]) # Apache 2.0 + +package( + default_visibility = [ + "//tensorflow/compiler/tf2xla:internal", + ], +) + +load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") +load("//tensorflow:tensorflow.bzl", "cuda_py_test") +load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") +load("//tensorflow/compiler/tests:build_defs.bzl", "tf_xla_py_test") +load("//tensorflow/compiler/tests:build_defs.bzl", "generate_backend_suites") + +generate_backend_suites() + +py_library( + name = "xla_test", + testonly = 1, + srcs = ["xla_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/compiler:compiler_py", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform", + "//tensorflow/python:variables", + ], +) + +cc_library( + name = "depthwise_conv2d_test_kernel", + testonly = 1, + srcs = ["depthwise_conv2d_test_kernel.cc"], + deps = ["//tensorflow/core:framework_lite"], +) + +tf_xla_py_test( + name = "binary_ops_test", + size = "small", + srcs = ["binary_ops_test.py"], + shard_count = 5, + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + 
"//tensorflow/python:math_ops", + "//tensorflow/python:math_ops_gen", + "//tensorflow/python:nn_ops", + "//tensorflow/python:nn_ops_gen", + "//tensorflow/python:platform_test", + ], +) + +tf_xla_py_test( + name = "clustering_test", + size = "small", + srcs = ["clustering_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + +tf_xla_py_test( + name = "concat_ops_test", + size = "small", + srcs = ["concat_ops_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:array_ops_gen", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:gradient_checker", + "//tensorflow/python:gradients", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + +tf_xla_py_test( + name = "conv2d_test", + size = "medium", + srcs = ["conv2d_test.py"], + shard_count = 10, + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:nn", + "//tensorflow/python:nn_ops", + "//tensorflow/python:nn_ops_gen", + "//tensorflow/python:platform_test", + ], +) + +tf_xla_py_test( + name = "dynamic_stitch_test", + size = "small", + srcs = ["dynamic_stitch_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:data_flow_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform_test", + ], +) + +tf_xla_py_test( + name = "function_test", + size = "small", + srcs = ["function_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform_test", + ], +) + +tf_xla_py_test( + name = "lrn_ops_test", + size = "medium", + srcs = ["lrn_ops_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:nn", + "//tensorflow/python:nn_ops_gen", + "//tensorflow/python:platform_test", + ], +) + +tf_xla_py_test( + name = "nary_ops_test", + size = "small", + srcs = ["nary_ops_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + +tf_xla_py_test( + name = "nullary_ops_test", + size = "small", + srcs = ["nullary_ops_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform_test", + ], +) + +tf_xla_py_test( + name = "pooling_ops_test", + size = "medium", + srcs = ["pooling_ops_test.py"], + shard_count = 10, + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:nn_ops", + "//tensorflow/python:nn_ops_gen", + "//tensorflow/python:platform_test", + ], +) + +tf_xla_py_test( + name = "reduce_ops_test", + size = "medium", + srcs = ["reduce_ops_test.py"], + shard_count = 5, + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:errors", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + +tf_xla_py_test( + name = "ternary_ops_test", + size = "small", + srcs = ["ternary_ops_test.py"], + deps = [ + 
":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + +tf_xla_py_test( + name = "unary_ops_test", + size = "small", + srcs = ["unary_ops_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn_ops", + "//tensorflow/python:nn_ops_gen", + "//tensorflow/python:platform_test", + ], +) + +cuda_py_test( + name = "xla_device_test", + size = "small", + srcs = ["xla_device_test.py"], + additional_deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + ], +) + +cuda_py_test( + name = "jit_test", + size = "medium", + srcs = ["jit_test.py"], + additional_deps = [ + "//tensorflow/contrib/compiler:compiler_py", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:gradients", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn_ops", + ], +) + +cc_library( + name = "randomized_tests_library", + testonly = 1, + srcs = ["randomized_tests.cc"], + deps = [ + "//tensorflow/compiler/jit", + "//tensorflow/compiler/jit:common", + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels:ops_util", + ], +) + +tf_cuda_cc_test( + name = "randomized_tests", + # This test is randomized, so only run it if explicitly requested. + tags = [ + "manual", + "noguitar", + "notap", + ], + deps = [":randomized_tests_library"], +) + +py_library( + name = "lstm", + testonly = 1, + srcs = ["lstm.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:variables", + ], +) + +cuda_py_test( + name = "lstm_test", + srcs = ["lstm_test.py"], + additional_deps = [ + ":lstm", + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:gradients", + "//tensorflow/python:init_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:variables", + ], +) + +# An example of ahead-of-time compilation using tfcompile. The +# lstm_layer_inference.pbtxt file was generated by running lstm_test +# --dump_graph_dir, and the config file was written by hand. 
+# +# Run the following to build a minimal benchmark of the computation on Android: +# $ bazel build -c opt --config=android_arm \ +# third_party/tensorflow/compiler/tests:lstm_layer_inference_benchmark +# +# Currently the resulting binary size is ~190KB +tf_library( + name = "lstm_layer_inference", + testonly = 1, + config = "lstm_layer_inference.config.pbtxt", + cpp_class = "LSTMLayerInference", + graph = "lstm_layer_inference.pbtxt", + tags = ["manual"], + tfcompile_flags = "--xla_cpu_multi_thread_eigen=false", +) + +# ----------------------------------------------------------------------------- + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py new file mode 100644 index 0000000000..9d197b0646 --- /dev/null +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -0,0 +1,749 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test cases for binary operators.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_math_ops +from tensorflow.python.ops import gen_nn_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.platform import googletest + + +class BinaryOpsTest(XLATestCase): + """Test cases for binary operators.""" + + def _testBinary(self, op, a, b, expected, equality_test=None): + with self.test_session() as session: + with self.test_scope(): + pa = array_ops.placeholder(dtypes.as_dtype(a.dtype), a.shape, name="a") + pb = array_ops.placeholder(dtypes.as_dtype(b.dtype), b.shape, name="b") + output = op(pa, pb) + result = session.run(output, {pa: a, pb: b}) + if equality_test is None: + equality_test = self.assertAllClose + equality_test(result, expected, rtol=1e-3) + + def ListsAreClose(self, result, expected, rtol): + """Tests closeness of two lists of floats.""" + self.assertEqual(len(result), len(expected)) + for i in range(len(result)): + self.assertAllClose(result[i], expected[i], rtol) + + def testFloatOps(self): + for dtype in self.float_types: + self._testBinary( + gen_math_ops._real_div, + np.array([3, 3, -1.5, -8, 44], dtype=dtype), + np.array([2, -2, 7, -4, 0], dtype=dtype), + expected=np.array( + [1.5, -1.5, -0.2142857, 2, float("inf")], dtype=dtype)) + + self._testBinary(math_ops.pow, dtype(3), dtype(4), expected=dtype(81)) + + self._testBinary( + math_ops.pow, + np.array([1, 2], dtype=dtype), + np.zeros(shape=[0, 2], dtype=dtype), + expected=np.zeros(shape=[0, 2], 
dtype=dtype)) + self._testBinary( + math_ops.pow, + np.array([10, 4], dtype=dtype), + np.array([2, 3], dtype=dtype), + expected=np.array([100, 64], dtype=dtype)) + self._testBinary( + math_ops.pow, + dtype(2), + np.array([3, 4], dtype=dtype), + expected=np.array([8, 16], dtype=dtype)) + self._testBinary( + math_ops.pow, + np.array([[2], [3]], dtype=dtype), + dtype(4), + expected=np.array([[16], [81]], dtype=dtype)) + + self._testBinary( + gen_math_ops._sigmoid_grad, + np.array([4, 3, 2, 1], dtype=dtype), + np.array([5, 6, 7, 8], dtype=dtype), + expected=np.array([-60, -36, -14, 0], dtype=dtype)) + + self._testBinary( + gen_math_ops._rsqrt_grad, + np.array([4, 3, 2, 1], dtype=dtype), + np.array([5, 6, 7, 8], dtype=dtype), + expected=np.array([-160, -81, -28, -4], dtype=dtype)) + + self._testBinary( + gen_nn_ops._softplus_grad, + np.array([4, 3, 2, 1], dtype=dtype), + np.array([5, 6, 7, 8], dtype=dtype), + expected=np.array( + [3.97322869, 2.99258232, 1.99817801, 0.99966466], dtype=dtype)) + + self._testBinary( + gen_math_ops._tanh_grad, + np.array([4, 3, 2, 1], dtype=dtype), + np.array([5, 6, 7, 8], dtype=dtype), + expected=np.array([-75, -48, -21, 0], dtype=dtype)) + + self._testBinary( + gen_nn_ops._relu_grad, + np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dtype), + np.array([0, 0, 0, 0, 0, 0.1, 0.3, 0.5, 0.7, 0.9], dtype=dtype), + expected=np.array([0, 0, 0, 0, 0, 6, 7, 8, 9, 10], dtype=dtype)) + + self._testBinary( + gen_nn_ops._relu6_grad, + np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], dtype=dtype), + np.array( + [0, 0, 0, 0, 0, 0.1, 0.3, 0.5, 0.7, 0.9, 6.1, 10.0], dtype=dtype), + expected=np.array([0, 0, 0, 0, 0, 6, 7, 8, 9, 10, 0, 0], dtype=dtype)) + + self._testBinary( + gen_nn_ops._softmax_cross_entropy_with_logits, + np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=dtype), + np.array([[0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1]], dtype=dtype), + expected=[ + np.array([1.44019, 2.44019], dtype=dtype), + np.array([[-0.067941, -0.112856, -0.063117, 0.243914], + [-0.367941, -0.212856, 0.036883, 0.543914]], + dtype=dtype), + ], + equality_test=self.ListsAreClose) + + def testIntOps(self): + for dtype in self.int_types: + self._testBinary( + gen_math_ops._truncate_div, + np.array([3, 3, -1, -9, -8], dtype=dtype), + np.array([2, -2, 7, 2, -4], dtype=dtype), + expected=np.array([1, -1, 0, -4, 2], dtype=dtype)) + + def testNumericOps(self): + for dtype in self.numeric_types: + self._testBinary( + math_ops.add, + np.array([1, 2], dtype=dtype), + np.array([10, 20], dtype=dtype), + expected=np.array([11, 22], dtype=dtype)) + self._testBinary( + math_ops.add, + dtype(5), + np.array([1, 2], dtype=dtype), + expected=np.array([6, 7], dtype=dtype)) + self._testBinary( + math_ops.add, + np.array([[1], [2]], dtype=dtype), + dtype(7), + expected=np.array([[8], [9]], dtype=dtype)) + + self._testBinary( + math_ops.sub, + np.array([1, 2], dtype=dtype), + np.array([10, 20], dtype=dtype), + expected=np.array([-9, -18], dtype=dtype)) + self._testBinary( + math_ops.sub, + dtype(5), + np.array([1, 2], dtype=dtype), + expected=np.array([4, 3], dtype=dtype)) + self._testBinary( + math_ops.sub, + np.array([[1], [2]], dtype=dtype), + dtype(7), + expected=np.array([[-6], [-5]], dtype=dtype)) + + self._testBinary( + math_ops.maximum, + np.array([1, 2], dtype=dtype), + np.array([10, 20], dtype=dtype), + expected=np.array([10, 20], dtype=dtype)) + self._testBinary( + math_ops.maximum, + dtype(5), + np.array([1, 20], dtype=dtype), + expected=np.array([5, 20], dtype=dtype)) + self._testBinary( + math_ops.maximum, + 
np.array([[10], [2]], dtype=dtype), + dtype(7), + expected=np.array([[10], [7]], dtype=dtype)) + + self._testBinary( + math_ops.minimum, + np.array([1, 20], dtype=dtype), + np.array([10, 2], dtype=dtype), + expected=np.array([1, 2], dtype=dtype)) + self._testBinary( + math_ops.minimum, + dtype(5), + np.array([1, 20], dtype=dtype), + expected=np.array([1, 5], dtype=dtype)) + self._testBinary( + math_ops.minimum, + np.array([[10], [2]], dtype=dtype), + dtype(7), + expected=np.array([[7], [2]], dtype=dtype)) + + self._testBinary( + math_ops.mul, + np.array([1, 20], dtype=dtype), + np.array([10, 2], dtype=dtype), + expected=np.array([10, 40], dtype=dtype)) + self._testBinary( + math_ops.mul, + dtype(5), + np.array([1, 20], dtype=dtype), + expected=np.array([5, 100], dtype=dtype)) + self._testBinary( + math_ops.mul, + np.array([[10], [2]], dtype=dtype), + dtype(7), + expected=np.array([[70], [14]], dtype=dtype)) + + self._testBinary( + math_ops.squared_difference, + np.array([1, 2], dtype=dtype), + np.array([10, 20], dtype=dtype), + expected=np.array([81, 324], dtype=dtype)) + self._testBinary( + math_ops.squared_difference, + dtype(5), + np.array([1, 2], dtype=dtype), + expected=np.array([16, 9], dtype=dtype)) + self._testBinary( + math_ops.squared_difference, + np.array([[1], [2]], dtype=dtype), + dtype(7), + expected=np.array([[36], [25]], dtype=dtype)) + + self._testBinary( + nn_ops.bias_add, + np.array([[1, 2], [3, 4]], dtype=dtype), + np.array([2, -1], dtype=dtype), + expected=np.array([[3, 1], [5, 3]], dtype=dtype)) + self._testBinary( + nn_ops.bias_add, + np.array([[[[1, 2], [3, 4]]]], dtype=dtype), + np.array([2, -1], dtype=dtype), + expected=np.array([[[[3, 1], [5, 3]]]], dtype=dtype)) + + def _testDivision(self, dtype): + """Test cases for division operators.""" + self._testBinary( + math_ops.div, + np.array([10, 20], dtype=dtype), + np.array([10, 2], dtype=dtype), + expected=np.array([1, 10], dtype=dtype)) + self._testBinary( + math_ops.div, + dtype(40), + np.array([2, 20], dtype=dtype), + expected=np.array([20, 2], dtype=dtype)) + self._testBinary( + math_ops.div, + np.array([[10], [4]], dtype=dtype), + dtype(2), + expected=np.array([[5], [2]], dtype=dtype)) + + self._testBinary( + gen_math_ops._floor_div, + np.array([3, 3, -1, -9, -8], dtype=dtype), + np.array([2, -2, 7, 2, -4], dtype=dtype), + expected=np.array([1, -2, -1, -5, 2], dtype=dtype)) + + def testIntDivision(self): + for dtype in self.int_types: + self._testDivision(dtype) + + def testFloatDivision(self): + for dtype in self.float_types: + self._testDivision(dtype) + + def _testRemainder(self, dtype): + """Test cases for remainder operators.""" + self._testBinary( + gen_math_ops._floor_mod, + np.array([3, 3, -1, -8], dtype=dtype), + np.array([2, -2, 7, -4], dtype=dtype), + expected=np.array([1, -1, 6, 0], dtype=dtype)) + self._testBinary( + gen_math_ops._truncate_mod, + np.array([3, 3, -1, -8], dtype=dtype), + np.array([2, -2, 7, -4], dtype=dtype), + expected=np.array([1, 1, -1, 0], dtype=dtype)) + + def testIntRemainder(self): + for dtype in self.int_types: + self._testRemainder(dtype) + + def testFloatRemainder(self): + for dtype in self.float_types: + self._testRemainder(dtype) + + def testLogicalOps(self): + self._testBinary( + math_ops.logical_and, + np.array([[True, False], [False, True]], dtype=np.bool), + np.array([[False, True], [False, True]], dtype=np.bool), + expected=np.array([[False, False], [False, True]], dtype=np.bool)) + + self._testBinary( + math_ops.logical_or, + np.array([[True, False], [False, 
True]], dtype=np.bool), + np.array([[False, True], [False, True]], dtype=np.bool), + expected=np.array([[True, True], [False, True]], dtype=np.bool)) + + def testComparisons(self): + self._testBinary( + math_ops.equal, + np.array([1, 5, 20], dtype=np.float32), + np.array([10, 5, 2], dtype=np.float32), + expected=np.array([False, True, False], dtype=np.bool)) + self._testBinary( + math_ops.equal, + np.float32(5), + np.array([1, 5, 20], dtype=np.float32), + expected=np.array([False, True, False], dtype=np.bool)) + self._testBinary( + math_ops.equal, + np.array([[10], [7], [2]], dtype=np.float32), + np.float32(7), + expected=np.array([[False], [True], [False]], dtype=np.bool)) + + self._testBinary( + math_ops.not_equal, + np.array([1, 5, 20], dtype=np.float32), + np.array([10, 5, 2], dtype=np.float32), + expected=np.array([True, False, True], dtype=np.bool)) + self._testBinary( + math_ops.not_equal, + np.float32(5), + np.array([1, 5, 20], dtype=np.float32), + expected=np.array([True, False, True], dtype=np.bool)) + self._testBinary( + math_ops.not_equal, + np.array([[10], [7], [2]], dtype=np.float32), + np.float32(7), + expected=np.array([[True], [False], [True]], dtype=np.bool)) + + for greater_op in [math_ops.greater, (lambda x, y: x > y)]: + self._testBinary( + greater_op, + np.array([1, 5, 20], dtype=np.float32), + np.array([10, 5, 2], dtype=np.float32), + expected=np.array([False, False, True], dtype=np.bool)) + self._testBinary( + greater_op, + np.float32(5), + np.array([1, 5, 20], dtype=np.float32), + expected=np.array([True, False, False], dtype=np.bool)) + self._testBinary( + greater_op, + np.array([[10], [7], [2]], dtype=np.float32), + np.float32(7), + expected=np.array([[True], [False], [False]], dtype=np.bool)) + + for greater_equal_op in [math_ops.greater_equal, (lambda x, y: x >= y)]: + self._testBinary( + greater_equal_op, + np.array([1, 5, 20], dtype=np.float32), + np.array([10, 5, 2], dtype=np.float32), + expected=np.array([False, True, True], dtype=np.bool)) + self._testBinary( + greater_equal_op, + np.float32(5), + np.array([1, 5, 20], dtype=np.float32), + expected=np.array([True, True, False], dtype=np.bool)) + self._testBinary( + greater_equal_op, + np.array([[10], [7], [2]], dtype=np.float32), + np.float32(7), + expected=np.array([[True], [True], [False]], dtype=np.bool)) + + for less_op in [math_ops.less, (lambda x, y: x < y)]: + self._testBinary( + less_op, + np.array([1, 5, 20], dtype=np.float32), + np.array([10, 5, 2], dtype=np.float32), + expected=np.array([True, False, False], dtype=np.bool)) + self._testBinary( + less_op, + np.float32(5), + np.array([1, 5, 20], dtype=np.float32), + expected=np.array([False, False, True], dtype=np.bool)) + self._testBinary( + less_op, + np.array([[10], [7], [2]], dtype=np.float32), + np.float32(7), + expected=np.array([[False], [False], [True]], dtype=np.bool)) + + for less_equal_op in [math_ops.less_equal, (lambda x, y: x <= y)]: + self._testBinary( + less_equal_op, + np.array([1, 5, 20], dtype=np.float32), + np.array([10, 5, 2], dtype=np.float32), + expected=np.array([True, True, False], dtype=np.bool)) + self._testBinary( + less_equal_op, + np.float32(5), + np.array([1, 5, 20], dtype=np.float32), + expected=np.array([False, True, True], dtype=np.bool)) + self._testBinary( + less_equal_op, + np.array([[10], [7], [2]], dtype=np.float32), + np.float32(7), + expected=np.array([[False], [True], [True]], dtype=np.bool)) + + def testBroadcasting(self): + """Tests broadcasting behavior of an operator.""" + + for dtype in 
self.numeric_types: + self._testBinary( + math_ops.add, + np.array(3, dtype=dtype), + np.array([10, 20], dtype=dtype), + expected=np.array([13, 23], dtype=dtype)) + self._testBinary( + math_ops.add, + np.array([10, 20], dtype=dtype), + np.array(4, dtype=dtype), + expected=np.array([14, 24], dtype=dtype)) + + # [1,3] x [4,1] => [4,3] + self._testBinary( + math_ops.add, + np.array([[10, 20, 30]], dtype=dtype), + np.array([[1], [2], [3], [4]], dtype=dtype), + expected=np.array( + [[11, 21, 31], [12, 22, 32], [13, 23, 33], [14, 24, 34]], + dtype=dtype)) + + # [3] * [4,1] => [4,3] + self._testBinary( + math_ops.add, + np.array([10, 20, 30], dtype=dtype), + np.array([[1], [2], [3], [4]], dtype=dtype), + expected=np.array( + [[11, 21, 31], [12, 22, 32], [13, 23, 33], [14, 24, 34]], + dtype=dtype)) + + def testFill(self): + for dtype in self.numeric_types: + self._testBinary( + array_ops.fill, + np.array([], dtype=np.int32), + dtype(-42), + expected=dtype(-42)) + self._testBinary( + array_ops.fill, + np.array([1, 2], dtype=np.int32), + dtype(7), + expected=np.array([[7, 7]], dtype=dtype)) + self._testBinary( + array_ops.fill, + np.array([3, 2], dtype=np.int32), + dtype(50), + expected=np.array([[50, 50], [50, 50], [50, 50]], dtype=dtype)) + + # Helper method used by testMatMul, testSparseMatMul, testBatchMatMul below. + def _testMatMul(self, op): + for dtype in self.float_types: + self._testBinary( + op, + np.array([[-0.25]], dtype=dtype), + np.array([[8]], dtype=dtype), + expected=np.array([[-2]], dtype=dtype)) + self._testBinary( + op, + np.array([[100, 10, 0.5]], dtype=dtype), + np.array([[1, 3], [2, 5], [6, 8]], dtype=dtype), + expected=np.array([[123, 354]], dtype=dtype)) + self._testBinary( + op, + np.array([[1, 3], [2, 5], [6, 8]], dtype=dtype), + np.array([[100], [10]], dtype=dtype), + expected=np.array([[130], [250], [680]], dtype=dtype)) + self._testBinary( + op, + np.array([[1000, 100], [10, 1]], dtype=dtype), + np.array([[1, 2], [3, 4]], dtype=dtype), + expected=np.array([[1300, 2400], [13, 24]], dtype=dtype)) + + self._testBinary( + op, + np.array([], dtype=dtype).reshape((2, 0)), + np.array([], dtype=dtype).reshape((0, 3)), + expected=np.array([[0, 0, 0], [0, 0, 0]], dtype=dtype)) + + def testMatMul(self): + self._testMatMul(math_ops.matmul) + + # TODO(phawkins): failing on GPU, no registered kernel. + def DISABLED_testSparseMatMul(self): + # Binary wrappers for sparse_matmul with different hints + def SparseMatmulWrapperTF(a, b): + return tf.sparse_matmul(a, b, a_is_sparse=True) + + def SparseMatmulWrapperFT(a, b): + return tf.sparse_matmul(a, b, b_is_sparse=True) + + def SparseMatmulWrapperTT(a, b): + return tf.sparse_matmul(a, b, a_is_sparse=True, b_is_sparse=True) + + self._testMatMul(tf.sparse_matmul) + self._testMatMul(SparseMatmulWrapperTF) + self._testMatMul(SparseMatmulWrapperFT) + self._testMatMul(SparseMatmulWrapperTT) + + def testBatchMatMul(self): + # Same tests as for tf.matmul above. + self._testMatMul(math_ops.matmul) + + # Tests with batches of matrices. 
+ self._testBinary( + math_ops.matmul, + np.array([[[-0.25]]], dtype=np.float32), + np.array([[[8]]], dtype=np.float32), + expected=np.array([[[-2]]], dtype=np.float32)) + self._testBinary( + math_ops.matmul, + np.array([[[-0.25]], [[4]]], dtype=np.float32), + np.array([[[8]], [[2]]], dtype=np.float32), + expected=np.array([[[-2]], [[8]]], dtype=np.float32)) + self._testBinary( + math_ops.matmul, + np.array( + [[[[1000, 100], [10, 1]], [[2000, 200], [20, 2]]], + [[[3000, 300], [30, 3]], [[4000, 400], [40, 4]]]], + dtype=np.float32), + np.array( + [[[[1, 2], [3, 4]], [[5, 6], [7, 8]]], [[[11, 22], [33, 44]], + [[55, 66], [77, 88]]]], + dtype=np.float32), + expected=np.array( + [[[[1300, 2400], [13, 24]], [[11400, 13600], [114, 136]]], + [[[42900, 79200], [429, 792]], [[250800, 299200], [2508, 2992]]]], + dtype=np.float32)) + self._testBinary( + math_ops.matmul, + np.array([], dtype=np.float32).reshape((2, 2, 0)), + np.array([], dtype=np.float32).reshape((2, 0, 3)), + expected=np.array( + [[[0, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]]], + dtype=np.float32)) + self._testBinary( + math_ops.matmul, + np.array([], dtype=np.float32).reshape((0, 2, 4)), + np.array([], dtype=np.float32).reshape((0, 4, 3)), + expected=np.array([], dtype=np.float32).reshape(0, 2, 3)) + + # Regression test for b/31472796. + if hasattr(np, "matmul"): + x = np.arange(0, 3 * 5 * 16 * 7, dtype=np.float32).reshape((3, 5, 16, 7)) + self._testBinary( + lambda x, y: math_ops.matmul(x, y, adjoint_b=True), + x, x, + expected=np.matmul(x, x.transpose([0, 1, 3, 2]))) + + def testExpandDims(self): + for dtype in self.numeric_types: + self._testBinary( + array_ops.expand_dims, + dtype(7), + np.int32(0), + expected=np.array([7], dtype=dtype)) + self._testBinary( + array_ops.expand_dims, + np.array([42], dtype=dtype), + np.int32(0), + expected=np.array([[42]], dtype=dtype)) + self._testBinary( + array_ops.expand_dims, + np.array([], dtype=dtype), + np.int32(0), + expected=np.array([[]], dtype=dtype)) + self._testBinary( + array_ops.expand_dims, + np.array([[[1, 2], [3, 4]]], dtype=dtype), + np.int32(0), + expected=np.array([[[[1, 2], [3, 4]]]], dtype=dtype)) + self._testBinary( + array_ops.expand_dims, + np.array([[[1, 2], [3, 4]]], dtype=dtype), + np.int32(1), + expected=np.array([[[[1, 2], [3, 4]]]], dtype=dtype)) + self._testBinary( + array_ops.expand_dims, + np.array([[[1, 2], [3, 4]]], dtype=dtype), + np.int32(2), + expected=np.array([[[[1, 2]], [[3, 4]]]], dtype=dtype)) + self._testBinary( + array_ops.expand_dims, + np.array([[[1, 2], [3, 4]]], dtype=dtype), + np.int32(3), + expected=np.array([[[[1], [2]], [[3], [4]]]], dtype=dtype)) + + def testPad(self): + for dtype in self.numeric_types: + self._testBinary( + array_ops.pad, + np.array( + [[1, 2, 3], [4, 5, 6]], dtype=dtype), + np.array( + [[1, 2], [2, 1]], dtype=np.int32), + expected=np.array( + [[0, 0, 0, 0, 0, 0], + [0, 0, 1, 2, 3, 0], + [0, 0, 4, 5, 6, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0]], + dtype=dtype)) + + def testReshape(self): + for dtype in self.numeric_types: + self._testBinary( + array_ops.reshape, + np.array([], dtype=dtype), + np.array([0, 4], dtype=np.int32), + expected=np.zeros(shape=[0, 4], dtype=dtype)) + self._testBinary( + array_ops.reshape, + np.array([0, 1, 2, 3, 4, 5], dtype=dtype), + np.array([2, 3], dtype=np.int32), + expected=np.array([[0, 1, 2], [3, 4, 5]], dtype=dtype)) + self._testBinary( + array_ops.reshape, + np.array([0, 1, 2, 3, 4, 5], dtype=dtype), + np.array([3, 2], dtype=np.int32), + expected=np.array([[0, 1], [2, 3], [4, 5]], 
dtype=dtype)) + self._testBinary( + array_ops.reshape, + np.array([0, 1, 2, 3, 4, 5], dtype=dtype), + np.array([-1, 6], dtype=np.int32), + expected=np.array([[0, 1, 2, 3, 4, 5]], dtype=dtype)) + self._testBinary( + array_ops.reshape, + np.array([0, 1, 2, 3, 4, 5], dtype=dtype), + np.array([6, -1], dtype=np.int32), + expected=np.array([[0], [1], [2], [3], [4], [5]], dtype=dtype)) + self._testBinary( + array_ops.reshape, + np.array([0, 1, 2, 3, 4, 5], dtype=dtype), + np.array([2, -1], dtype=np.int32), + expected=np.array([[0, 1, 2], [3, 4, 5]], dtype=dtype)) + self._testBinary( + array_ops.reshape, + np.array([0, 1, 2, 3, 4, 5], dtype=dtype), + np.array([-1, 3], dtype=np.int32), + expected=np.array([[0, 1, 2], [3, 4, 5]], dtype=dtype)) + + def testSplit(self): + for dtype in self.numeric_types: + self._testBinary( + lambda x, y: array_ops.split(value=y, num_or_size_splits=3, axis=x), + np.int32(0), + np.array([[[1], [2]], [[3], [4]], [[5], [6]]], + dtype=dtype), + expected=[ + np.array([[[1], [2]]], dtype=dtype), + np.array([[[3], [4]]], dtype=dtype), + np.array([[[5], [6]]], dtype=dtype), + ], + equality_test=self.ListsAreClose) + + self._testBinary( + lambda x, y: array_ops.split(value=y, num_or_size_splits=2, axis=x), + np.int32(1), + np.array([[[1], [2]], [[3], [4]], [[5], [6]]], + dtype=dtype), + expected=[ + np.array([[[1]], [[3]], [[5]]], dtype=dtype), + np.array([[[2]], [[4]], [[6]]], dtype=dtype), + ], + equality_test=self.ListsAreClose) + + def testTile(self): + for dtype in self.numeric_types: + self._testBinary( + array_ops.tile, + np.array([[6]], dtype=dtype), + np.array([1, 2], dtype=np.int32), + expected=np.array([[6, 6]], dtype=dtype)) + self._testBinary( + array_ops.tile, + np.array([[1], [2]], dtype=dtype), + np.array([1, 2], dtype=np.int32), + expected=np.array([[1, 1], [2, 2]], dtype=dtype)) + self._testBinary( + array_ops.tile, + np.array([[1, 2], [3, 4]], dtype=dtype), + np.array([3, 2], dtype=np.int32), + expected=np.array( + [[1, 2, 1, 2], + [3, 4, 3, 4], + [1, 2, 1, 2], + [3, 4, 3, 4], + [1, 2, 1, 2], + [3, 4, 3, 4]], + dtype=dtype)) + self._testBinary( + array_ops.tile, + np.array([[1, 2], [3, 4]], dtype=dtype), + np.array([1, 1], dtype=np.int32), + expected=np.array( + [[1, 2], + [3, 4]], + dtype=dtype)) + self._testBinary( + array_ops.tile, + np.array([[1, 2]], dtype=dtype), + np.array([3, 1], dtype=np.int32), + expected=np.array( + [[1, 2], + [1, 2], + [1, 2]], + dtype=dtype)) + + def testTranspose(self): + for dtype in self.numeric_types: + self._testBinary( + array_ops.transpose, + np.zeros(shape=[1, 0, 4], dtype=dtype), + np.array([1, 2, 0], dtype=np.int32), + expected=np.zeros(shape=[0, 4, 1], dtype=dtype)) + self._testBinary( + array_ops.transpose, + np.array([[1, 2], [3, 4]], dtype=dtype), + np.array([0, 1], dtype=np.int32), + expected=np.array([[1, 2], [3, 4]], dtype=dtype)) + self._testBinary( + array_ops.transpose, + np.array([[1, 2], [3, 4]], dtype=dtype), + np.array([1, 0], dtype=np.int32), + expected=np.array([[1, 3], [2, 4]], dtype=dtype)) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/compiler/tests/build_defs.bzl b/tensorflow/compiler/tests/build_defs.bzl new file mode 100644 index 0000000000..7fb8e0a26d --- /dev/null +++ b/tensorflow/compiler/tests/build_defs.bzl @@ -0,0 +1,78 @@ +"""Build rules for Tensorflow/XLA testing.""" + +load("@local_config_cuda//cuda:build_defs.bzl", "cuda_is_configured") + +def all_backends(): + if cuda_is_configured(): + return ["cpu", "gpu"] + else: + return ["cpu"] + +def 
tf_xla_py_test(name, srcs=[], deps=[], tags=[], data=[], main=None,
+                   backends=None, **kwargs):
+  """Generates py_test targets, one per XLA backend.
+
+  This rule generates py_test() targets named name_backend, one for each
+  backend in all_backends(). The rule also generates a test suite named
+  `name` that runs the test on every backend.
+
+  For example, the following rule generates test cases foo_test_cpu,
+  foo_test_gpu, and a test suite named foo_test that runs both:
+  tf_xla_py_test(
+      name="foo_test",
+      srcs=["foo_test.py"],
+      deps=[...],
+  )
+
+  Args:
+    name: Name of the target.
+    srcs: Sources for the target.
+    deps: Dependencies of the target.
+    tags: Tags to apply to the generated targets.
+    data: Data dependencies of the target.
+    main: Same as py_test's main attribute.
+    backends: A list of backends to test. Supported values include "cpu" and
+      "gpu". If not specified, defaults to all backends.
+    **kwargs: keyword arguments passed onto the generated py_test() rules.
+  """
+  if backends == None:
+    backends = all_backends()
+
+  test_names = []
+  for backend in backends:
+    test_name = "{}_{}".format(name, backend)
+    backend_tags = ["tf_xla_{}".format(backend)]
+    backend_args = []
+    backend_deps = []
+    backend_data = []
+    if backend == "cpu":
+      backend_args += ["--test_device=XLA_CPU",
+                       "--types=DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL"]
+    elif backend == "gpu":
+      backend_args += ["--test_device=XLA_GPU",
+                       "--types=DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL"]
+      backend_tags += ["requires-gpu-sm35"]
+    else:
+      fail("Unknown backend {}".format(backend))
+
+    native.py_test(
+        name=test_name,
+        srcs=srcs,
+        srcs_version="PY2AND3",
+        args=backend_args,
+        main="{}.py".format(name) if main == None else main,
+        data=data + backend_data,
+        deps=deps + backend_deps,
+        tags=tags + backend_tags,
+        **kwargs
+    )
+    test_names.append(test_name)
+  native.test_suite(name=name, tests=test_names)
+
+def generate_backend_suites(backends=[]):
+  """Generates per-backend test_suites that run all tests for a backend."""
+  if not backends:
+    backends = all_backends()
+  for backend in backends:
+    native.test_suite(name="%s_tests" % backend, tags=["tf_xla_%s" % backend])
+
diff --git a/tensorflow/compiler/tests/clustering_test.py b/tensorflow/compiler/tests/clustering_test.py
new file mode 100644
index 0000000000..6993648531
--- /dev/null
+++ b/tensorflow/compiler/tests/clustering_test.py
@@ -0,0 +1,102 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the behavior of the auto-compilation pass."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import googletest
+
+CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0"
+
+
+class ClusteringTest(XLATestCase):
+
+  def testAdd(self):
+    val1 = np.array([4, 3, 2, 1], dtype=np.float32)
+    val2 = np.array([5, 6, 7, 8], dtype=np.float32)
+    expected = val1 + val2
+    with self.test_session():
+      with self.test_scope():
+        input1 = constant_op.constant(val1, name="const1")
+        input2 = constant_op.constant(val2, name="const2")
+        output = math_ops.add(input1, input2)
+      result = output.eval()
+      self.assertAllClose(result, expected, rtol=1e-3)
+
+  def testAddFromCpuMultiple(self):
+    val1 = np.array([4, 3, 2, 1]).astype(np.float32)
+    val2 = np.array([5, 6, 7, 8]).astype(np.float32)
+    expected = val1 + val2
+    with self.test_session():
+      with ops.device(CPU_DEVICE):
+        input1 = constant_op.constant(val1, name="const1")
+        input2 = constant_op.constant(val2, name="const2")
+      with self.test_scope():
+        output = math_ops.add(input1, input2)
+      for _ in range(10):
+        result = output.eval()
+        self.assertAllClose(result, expected, rtol=1e-3)
+
+  def testDeadlock(self):
+    # Builds a graph of the form:
+    #   x -> y
+    #        | \
+    #        z -> w
+    # where x and z are placed on the CPU and y and w are placed on the XLA
+    # device. If y and w are clustered for compilation, then the graph will
+    # deadlock since the clustered graph will contain a self-loop.
+    with self.test_session() as sess:
+      with ops.device(CPU_DEVICE):
+        x = array_ops.placeholder(dtypes.float32, [2])
+      with self.test_scope():
+        y = x * 2
+      with ops.device(CPU_DEVICE):
+        z = y * y
+      with self.test_scope():
+        w = y + z
+      result = sess.run(w, {x: [1.5, 0.5]})
+      self.assertAllClose(result, [12., 2.], rtol=1e-3)
+
+  def testHostMemory(self):
+    with self.test_session() as sess:
+      x = array_ops.placeholder(dtypes.int32)
+      with self.test_scope():
+        y = x + 1
+      with ops.device(CPU_DEVICE):
+        # Place a computation on the CPU, so y and w cannot be merged into the
+        # same JIT compilation.
+        z = y * 2
+      with self.test_scope():
+        # Argument 'y' is a non-constant output of a previous cluster. Make sure
+        # it is properly copied to host memory so it can be used as a
+        # compile-time constant input for this cluster.
+        w = array_ops.reshape(z, y)
+      result = sess.run(w, {x: [1, 0]})
+      expected = np.array([[4], [2]], dtype=np.int32)
+      self.assertAllClose(expected, result, rtol=1e-3)
+
+
+if __name__ == "__main__":
+  googletest.main()
diff --git a/tensorflow/compiler/tests/concat_ops_test.py b/tensorflow/compiler/tests/concat_ops_test.py
new file mode 100644
index 0000000000..6f74ae702c
--- /dev/null
+++ b/tensorflow/compiler/tests/concat_ops_test.py
@@ -0,0 +1,374 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Functional tests for XLA Concat Op.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import gradient_checker +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import googletest + + +class ConcatTest(XLATestCase): + + def testHStack(self): + with self.test_session(): + p1 = array_ops.placeholder(dtypes.float32, shape=[4, 4]) + p2 = array_ops.placeholder(dtypes.float32, shape=[4, 4]) + with self.test_scope(): + c = array_ops.concat_v2([p1, p2], 0) + params = { + p1: np.random.rand(4, 4).astype("f"), + p2: np.random.rand(4, 4).astype("f") + } + result = c.eval(feed_dict=params) + + self.assertEqual(result.shape, c.get_shape()) + self.assertAllEqual(result[:4, :], params[p1]) + self.assertAllEqual(result[4:, :], params[p2]) + + def testVStack(self): + with self.test_session(): + p1 = array_ops.placeholder(dtypes.float32, shape=[4, 4]) + p2 = array_ops.placeholder(dtypes.float32, shape=[4, 4]) + with self.test_scope(): + c = array_ops.concat_v2([p1, p2], 1) + params = { + p1: np.random.rand(4, 4).astype("f"), + p2: np.random.rand(4, 4).astype("f") + } + result = c.eval(feed_dict=params) + + self.assertEqual(result.shape, c.get_shape()) + self.assertAllEqual(result[:, :4], params[p1]) + self.assertAllEqual(result[:, 4:], params[p2]) + + def testInt32(self): + with self.test_session(): + p1 = np.random.rand(2, 3).astype("i") + p2 = np.random.rand(2, 3).astype("i") + x1 = constant_op.constant(p1) + x2 = constant_op.constant(p2) + with self.test_scope(): + c = array_ops.concat_v2([x1, x2], 0) + result = c.eval() + self.assertAllEqual(result[:2, :], p1) + self.assertAllEqual(result[2:, :], p2) + + def _testRandom(self, dtype): + # Random dims of rank 5 + shape = np.random.randint(1, 5, size=5) + # Random number of tensors, but always > 1. 
+ num_tensors = np.random.randint(2, 10) + # Random dim to concat on + concat_dim = np.random.randint(5) + params = {} + if dtype == dtypes.bfloat16: + dtype_feed = dtypes.float32 + else: + dtype_feed = dtype + with self.test_session(): + p = [] + for i in np.arange(num_tensors): + input_shape = shape + input_shape[concat_dim] = np.random.randint(1, 5) + placeholder = array_ops.placeholder(dtype_feed, shape=input_shape) + p.append(placeholder) + + t = dtype_feed.as_numpy_dtype + params[placeholder] = np.random.rand(*input_shape).astype(t) + + if dtype != dtype_feed: + concat_inputs = [math_ops.cast(p_i, dtype) for p_i in p] + else: + concat_inputs = p + with self.test_scope(): + c = array_ops.concat_v2(concat_inputs, concat_dim) + if dtype != dtype_feed: + c = math_ops.cast(c, dtype_feed) + result = c.eval(feed_dict=params) + + self.assertEqual(result.shape, c.get_shape()) + cur_offset = 0 + + for i in np.arange(num_tensors): + # The index into the result is the ':' along all dimensions + # except the concat_dim. slice(0, size) is used for ':', and + # a list of slices is used to index into result. + ind = [slice(0, params[p[i]].shape[j]) for j in np.arange(5)] + ind[concat_dim] = slice(cur_offset, + cur_offset + params[p[i]].shape[concat_dim]) + cur_offset += params[p[i]].shape[concat_dim] + if dtype == dtype_feed: + self.assertAllEqual(result[ind], params[p[i]]) + else: + self.assertAllClose(result[ind], params[p[i]], 0.01) + + def testRandom(self): + self._testRandom(dtypes.float32) + self._testRandom(dtypes.int32) + + def _testGradientsSimple(self): + with self.test_session(): + inp = [] + inp_tensors = [] + with self.test_scope(): + for x in [1, 2, 6]: + shape = [10, x, 2] + t = np.random.rand(*shape).astype("f") + inp.append(t) + inp_tensors.append( + constant_op.constant( + [float(y) for y in t.flatten()], + shape=shape, + dtype=dtypes.float32)) + c = array_ops.concat_v2(inp_tensors, 1) + output_shape = [10, 9, 2] + grad_inp = np.random.rand(*output_shape).astype("f") + grad_tensor = constant_op.constant( + [float(x) for x in grad_inp.flatten()], shape=output_shape) + grad = gradients_impl.gradients([c], inp_tensors, [grad_tensor]) + concated_grad = array_ops.concat_v2(grad, 1) + result = concated_grad.eval() + self.assertAllEqual(result, grad_inp) + + def testGradientsSimpleAll(self): + self._testGradientsSimple() + + def _testGradientsFirstDim(self): + with self.test_session(): + inp = [] + inp_tensors = [] + with self.test_scope(): + for x in [1, 2, 6]: + shape = [x, 10, 2] + t = np.random.rand(*shape).astype("f") + inp.append(t) + inp_tensors.append( + constant_op.constant( + [float(y) for y in t.flatten()], + shape=shape, + dtype=dtypes.float32)) + c = array_ops.concat_v2(inp_tensors, 0) + output_shape = [9, 10, 2] + grad_inp = np.random.rand(*output_shape).astype("f") + grad_tensor = constant_op.constant( + [float(x) for x in grad_inp.flatten()], shape=output_shape) + grad = gradients_impl.gradients([c], inp_tensors, [grad_tensor]) + concated_grad = array_ops.concat_v2(grad, 0) + result = concated_grad.eval() + + self.assertAllEqual(result, grad_inp) + + def testGradientsFirstDimAll(self): + self._testGradientsFirstDim() + + def _testGradientsLastDim(self): + with self.test_session(): + inp = [] + inp_tensors = [] + with self.test_scope(): + for x in [1, 2, 6]: + shape = [10, 2, x] + t = np.random.rand(*shape).astype("f") + inp.append(t) + inp_tensors.append( + constant_op.constant( + [float(y) for y in t.flatten()], + shape=shape, + dtype=dtypes.float32)) + c = 
array_ops.concat_v2(inp_tensors, 2) + output_shape = [10, 2, 9] + grad_inp = np.random.rand(*output_shape).astype("f") + grad_tensor = constant_op.constant( + [float(x) for x in grad_inp.flatten()], shape=output_shape) + grad = gradients_impl.gradients([c], inp_tensors, [grad_tensor]) + concated_grad = array_ops.concat_v2(grad, 2) + result = concated_grad.eval() + + self.assertAllEqual(result, grad_inp) + + def testGradientsLastDimAll(self): + self._testGradientsLastDim() + + def _RunAndVerifyGradientsRandom(self): + # Random dims of rank 5 + input_shape = np.random.randint(1, 5, size=5) + # Random number of tensors + num_tensors = np.random.randint(1, 10) + # Random dim to concat on + concat_dim = np.random.randint(5) + concat_dim_sizes = np.random.randint(1, 5, size=num_tensors) + with self.test_session(): + inp = [] + inp_tensors = [] + with self.test_scope(): + for x in concat_dim_sizes: + shape = input_shape + shape[concat_dim] = x + t = np.random.rand(*shape).astype("f") + inp.append(t) + inp_tensors.append( + constant_op.constant( + [float(y) for y in t.flatten()], + shape=shape, + dtype=dtypes.float32)) + c = array_ops.concat_v2(inp_tensors, concat_dim) + output_shape = input_shape + output_shape[concat_dim] = concat_dim_sizes.sum() + grad_inp = np.random.rand(*output_shape).astype("f") + grad_tensor = constant_op.constant( + [float(x) for x in grad_inp.flatten()], shape=output_shape) + grad = gradients_impl.gradients([c], inp_tensors, [grad_tensor]) + concated_grad = array_ops.concat_v2(grad, concat_dim) + result = concated_grad.eval() + + self.assertAllEqual(result, grad_inp) + + def testGradientsRandom(self): + for _ in range(5): + self._RunAndVerifyGradientsRandom() + + # Re-enable once zero-element Retvals are handled correctly. + def DISABLED_testZeroSize(self): + # Verify that concat doesn't crash and burn for zero size inputs + np.random.seed(7) + with self.test_session() as sess: + with self.test_scope(): + for shape0 in (), (2,): + axis = len(shape0) + for shape1 in (), (3,): + for n0 in 0, 1, 2: + for n1 in 0, 1, 2: + x0 = np.random.randn(*(shape0 + (n0,) + shape1)) + x1 = np.random.randn(*(shape0 + (n1,) + shape1)) + correct = np.concatenate([x0, x1], axis=axis) + # TODO(irving): Make tf.concat handle map, then drop list(). 
+ xs = list(map(constant_op.constant, [x0, x1])) + c = array_ops.concat_v2(xs, axis) + self.assertAllEqual(c.eval(), correct) + # Check gradients + dc = np.random.randn(*c.get_shape().as_list()) + dxs = sess.run(gradients_impl.gradients(c, xs, dc)) + self.assertAllEqual(dc, np.concatenate(dxs, axis=axis)) + + def testTensorConcatDim0Grad(self): + x_shapes = [[20, 7, 3], [10, 7, 3], [14, 7, 3]] + output_shape = [44, 7, 3] + x_vals = [ + np.random.random_sample(x_shape).astype(np.float32) + for x_shape in x_shapes + ] + with self.test_session(): + with self.test_scope(): + xs = [constant_op.constant(x_val) for x_val in x_vals] + output = array_ops.concat_v2(xs, 0) + err = gradient_checker.compute_gradient_error(xs, x_shapes, output, + output_shape) + self.assertLess(err, 1e-4) + + def testTensorConcatDim1Grad(self): + x_shapes = [[20, 7, 3], [20, 3, 3], [20, 1, 3]] + output_shape = [20, 11, 3] + x_vals = [ + np.random.random_sample(x_shape).astype(np.float32) + for x_shape in x_shapes + ] + with self.test_session(): + with self.test_scope(): + xs = [constant_op.constant(x_val) for x_val in x_vals] + output = array_ops.concat_v2(xs, 1) + err = gradient_checker.compute_gradient_error(xs, x_shapes, output, + output_shape) + self.assertLess(err, 1e-4) + + def testConcatTuple(self): + c1 = np.random.rand(4, 4).astype(np.float32) + c2 = np.random.rand(4, 4).astype(np.float32) + with self.test_session(): + with self.test_scope(): + concat_list_t = array_ops.concat_v2([c1, c2], 0) + concat_tuple_t = array_ops.concat_v2((c1, c2), 0) + self.assertAllEqual(concat_list_t.eval(), concat_tuple_t.eval()) + + def testConcatNoScalars(self): + with self.test_session(): + with self.test_scope(): + scalar = constant_op.constant(7) + dim = array_ops.placeholder(dtypes.int32) + with self.assertRaisesRegexp( + ValueError, r"Can't concatenate scalars \(use tf\.pack instead\)"): + array_ops.concat_v2([scalar, scalar, scalar], dim) + + +class ConcatOffsetTest(XLATestCase): + + def testBasic(self): + with self.test_session() as sess: + with self.test_scope(): + cdim = constant_op.constant(1, dtypes.int32) + s0 = constant_op.constant([2, 3, 5], dtypes.int32) + s1 = constant_op.constant([2, 7, 5], dtypes.int32) + s2 = constant_op.constant([2, 20, 5], dtypes.int32) + off = gen_array_ops._concat_offset(cdim, [s0, s1, s2]) + ans = sess.run(off) + self.assertAllEqual(ans, [[0, 0, 0], [0, 3, 0], [0, 10, 0]]) + + +class PackTest(XLATestCase): + + def testBasic(self): + with self.test_session() as sess: + with self.test_scope(): + s0 = constant_op.constant([2, 3, 5], dtypes.int32) + s1 = constant_op.constant([2, 7, 5], dtypes.int32) + s2 = constant_op.constant([2, 20, 5], dtypes.int32) + packed = array_ops.stack([s0, s1, s2]) + ans = sess.run(packed) + self.assertAllEqual(ans, [[2, 3, 5], [2, 7, 5], [2, 20, 5]]) + + def testScalars(self): + with self.test_session() as sess: + with self.test_scope(): + s0 = constant_op.constant(2, dtypes.int32) + s1 = constant_op.constant(3, dtypes.int32) + s2 = constant_op.constant(5, dtypes.int32) + packed = array_ops.stack([s0, s1, s2]) + ans = sess.run(packed) + self.assertAllEqual(ans, [2, 3, 5]) + + def testEmpty(self): + with self.test_session() as sess: + with self.test_scope(): + s0 = constant_op.constant([[]], dtypes.int32) + s1 = constant_op.constant([[]], dtypes.int32) + s2 = constant_op.constant([[]], dtypes.int32) + packed = array_ops.stack([s0, s1, s2]) + ans = sess.run(packed) + self.assertAllEqual(ans, [[[]], [[]], [[]]]) + + +if __name__ == "__main__": + googletest.main() diff 
--git a/tensorflow/compiler/tests/conv2d_test.py b/tensorflow/compiler/tests/conv2d_test.py new file mode 100644 index 0000000000..01cfbd9f7c --- /dev/null +++ b/tensorflow/compiler/tests/conv2d_test.py @@ -0,0 +1,526 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Conv2D via the XLA JIT. + +The canned results in these tests are created by running each test using the +Tensorflow CPU device and saving the output. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_nn_ops +from tensorflow.python.ops import nn_impl +from tensorflow.python.ops import nn_ops +from tensorflow.python.platform import googletest + + +class Conv2DTest(XLATestCase): + + def _VerifyValues(self, input_sizes, filter_sizes, stride, padding, expected): + """Tests that tf.nn.conv2d produces the expected value. + + Args: + input_sizes: Input tensor dimensions in + [batch, input_rows, input_cols, input_depth]. + filter_sizes: Filter tensor dimensions in + [kernel_rows, kernel_cols, input_depth, output_depth]. + stride: Stride. + padding: Padding type. + expected: Expected output. 
+ """ + total_size_1 = np.prod(input_sizes) + total_size_2 = np.prod(filter_sizes) + x1 = np.arange(1, total_size_1 + 1, dtype=np.float32).reshape(input_sizes) + x2 = np.arange(1, total_size_2 + 1, dtype=np.float32).reshape(filter_sizes) + strides = [1, stride, stride, 1] + + with self.test_session() as sess: + with self.test_scope(): + t1 = array_ops.placeholder(dtypes.float32, shape=input_sizes) + t2 = array_ops.placeholder(dtypes.float32, shape=filter_sizes) + out = nn_ops.conv2d( + t1, t2, strides=strides, padding=padding, data_format="NHWC") + value = sess.run(out, {t1: x1, t2: x2}) + self.assertArrayNear(expected, np.ravel(value), 1e-3) + + def testConv2D1x1Filter(self): + expected_output = [ + 30.0, 36.0, 42.0, 66.0, 81.0, 96.0, 102.0, 126.0, 150.0, 138.0, 171.0, + 204.0, 174.0, 216.0, 258.0, 210.0, 261.0, 312.0 + ] + self._VerifyValues( + input_sizes=[1, 2, 3, 3], + filter_sizes=[1, 1, 3, 3], + stride=1, + padding="VALID", + expected=expected_output) + + def testConv2D2x2Filter(self): + expected_output = [2271.0, 2367.0, 2463.0, 2901.0, 3033.0, 3165.0] + self._VerifyValues( + input_sizes=[1, 2, 3, 3], + filter_sizes=[2, 2, 3, 3], + stride=1, + padding="VALID", + expected=expected_output) + + def testConv2D1x2Filter(self): + expected_output = [ + 231.0, 252.0, 273.0, 384.0, 423.0, 462.0, 690.0, 765.0, 840.0, 843.0, + 936.0, 1029.0 + ] + self._VerifyValues( + input_sizes=[1, 2, 3, 3], + filter_sizes=[1, 2, 3, 3], + stride=1, + padding="VALID", + expected=expected_output) + + def testConv2D2x2FilterStride2(self): + expected_output = [2271.0, 2367.0, 2463.0] + self._VerifyValues( + input_sizes=[1, 2, 3, 3], + filter_sizes=[2, 2, 3, 3], + stride=2, + padding="VALID", + expected=expected_output) + + def testConv2D2x2FilterStride2Same(self): + expected_output = [2271.0, 2367.0, 2463.0, 1230.0, 1305.0, 1380.0] + self._VerifyValues( + input_sizes=[1, 2, 3, 3], + filter_sizes=[2, 2, 3, 3], + stride=2, + padding="SAME", + expected=expected_output) + + +class Conv2DBackpropInputTest(XLATestCase): + + def _VerifyValues(self, input_sizes, filter_sizes, out_backprop_sizes, stride, + padding, expected): + """Tests that gen_nn_ops.conv2d_backprop_input produces the expected output. + + Args: + input_sizes: Input tensor dimensions in + [batch, input_rows, input_cols, input_depth]. + filter_sizes: Filter tensor dimensions in + [kernel_rows, kernel_cols, input_depth, output_depth]. + out_backprop_sizes: Output gradients tensor dimensions. + stride: Stride. + padding: Padding type. + expected: Expected output. 
+ """ + total_size_1 = np.prod(filter_sizes) + total_size_2 = np.prod(out_backprop_sizes) + x1 = np.arange(1, total_size_1 + 1, dtype=np.float32).reshape(filter_sizes) + x2 = np.arange( + 1, total_size_2 + 1, dtype=np.float32).reshape(out_backprop_sizes) + strides = [1, stride, stride, 1] + + with self.test_session() as sess: + with self.test_scope(): + t1 = array_ops.placeholder(dtypes.float32, shape=filter_sizes) + t2 = array_ops.placeholder(dtypes.float32, shape=out_backprop_sizes) + out = gen_nn_ops.conv2d_backprop_input( + input_sizes=input_sizes, + filter=t1, + out_backprop=t2, + strides=strides, + padding=padding, + data_format="NHWC") + value = sess.run(out, {t1: x1, t2: x2}) + self.assertArrayNear(expected, np.ravel(value), 1e-3) + + def testConv2D1x1Filter(self): + expected_output = [ + 5, 11, 17, 11, 25, 39, 17, 39, 61, 23, 53, 83, 29, 67, 105, 35, 81, 127, + 41, 95, 149, 47, 109, 171, 53, 123, 193, 59, 137, 215, 65, 151, 237, 71, + 165, 259, 77, 179, 281, 83, 193, 303, 89, 207, 325, 95, 221, 347. + ] + self._VerifyValues( + input_sizes=[1, 4, 4, 3], + filter_sizes=[1, 1, 3, 2], + out_backprop_sizes=[1, 4, 4, 2], + stride=1, + padding="VALID", + expected=expected_output) + + def testConv2D1x2FilterStride3Width5(self): + expected_output = [1, 2, 0, 2, 4] + self._VerifyValues( + input_sizes=[1, 1, 5, 1], + filter_sizes=[1, 2, 1, 1], + out_backprop_sizes=[1, 1, 2, 1], + stride=3, + padding="VALID", + expected=expected_output) + + def testConv2D1x2FilterStride3Width6(self): + expected_output = [1, 2, 0, 2, 4, 0] + self._VerifyValues( + input_sizes=[1, 1, 6, 1], + filter_sizes=[1, 2, 1, 1], + out_backprop_sizes=[1, 1, 2, 1], + stride=3, + padding="VALID", + expected=expected_output) + + def testConv2D1x2FilterStride3Width7(self): + expected_output = [1, 2, 0, 2, 4, 0, 0] + self._VerifyValues( + input_sizes=[1, 1, 7, 1], + filter_sizes=[1, 2, 1, 1], + out_backprop_sizes=[1, 1, 2, 1], + stride=3, + padding="VALID", + expected=expected_output) + + def testConv2D2x2FilterC1Same(self): + expected_output = [1, 4, 7, 7, 23, 33] + self._VerifyValues( + input_sizes=[1, 2, 3, 1], + filter_sizes=[2, 2, 1, 1], + out_backprop_sizes=[1, 2, 3, 1], + stride=1, + padding="SAME", + expected=expected_output) + + def testConv2D2x2Filter(self): + expected_output = [ + 14, 32, 50, 100, 163, 226, 167, 212, 257, 122, 140, 158, 478, 541, 604, + 437, 482, 527 + ] + self._VerifyValues( + input_sizes=[1, 2, 3, 3], + filter_sizes=[2, 2, 3, 3], + out_backprop_sizes=[1, 1, 2, 3], + stride=1, + padding="VALID", + expected=expected_output) + + def testConv2D2x2FilterSame(self): + expected_output = [ + 14, 32, 50, 100, 163, 226, 217, 334, 451, 190, 307, 424, 929, 1217, + 1505, 1487, 1883, 2279 + ] + self._VerifyValues( + input_sizes=[1, 2, 3, 3], + filter_sizes=[2, 2, 3, 3], + out_backprop_sizes=[1, 2, 3, 3], + stride=1, + padding="SAME", + expected=expected_output) + + def testConv2D1x2Filter(self): + expected_output = [1, 4, 4, 3, 10, 8, 5, 16, 12] + self._VerifyValues( + input_sizes=[1, 3, 3, 1], + filter_sizes=[1, 2, 1, 1], + out_backprop_sizes=[1, 3, 2, 1], + stride=1, + padding="VALID", + expected=expected_output) + + def testConv2D1x2FilterSame(self): + expected_output = [1, 4, 7, 4, 13, 16, 7, 22, 25] + self._VerifyValues( + input_sizes=[1, 3, 3, 1], + filter_sizes=[1, 2, 1, 1], + out_backprop_sizes=[1, 3, 3, 1], + stride=1, + padding="SAME", + expected=expected_output) + + def testConv2D2x2FilterStride2(self): + expected_output = [1, 2, 5, 4, 6, 0, 0, 0, 0, 0, 3, 6, 13, 8, 12] + self._VerifyValues( + 
input_sizes=[1, 3, 5, 1], + filter_sizes=[1, 3, 1, 1], + out_backprop_sizes=[1, 2, 2, 1], + stride=2, + padding="VALID", + expected=expected_output) + + def testConv2D2x2FilterStride2Same(self): + expected_output = [1, 2, 2, 3, 4, 6] + self._VerifyValues( + input_sizes=[1, 2, 3, 1], + filter_sizes=[2, 2, 1, 1], + out_backprop_sizes=[1, 1, 2, 1], + stride=2, + padding="SAME", + expected=expected_output) + + +class Conv2DBackpropFilterTest(XLATestCase): + + def _VerifyValues(self, input_sizes, filter_sizes, out_backprop_sizes, stride, + padding, expected): + """Tests that gen_nn_ops.conv2d_backprop_filter produces the right output. + + Args: + input_sizes: Input tensor dimensions in + [batch, input_rows, input_cols, input_depth]. + filter_sizes: Filter tensor dimensions in + [kernel_rows, kernel_cols, input_depth, output_depth]. + out_backprop_sizes: Output gradients tensor dimensions. + stride: Stride. + padding: Padding type. + expected: Expected output. + """ + + total_size_1 = np.prod(input_sizes) + total_size_2 = np.prod(out_backprop_sizes) + x1 = np.arange(1, total_size_1 + 1, dtype=np.float32).reshape(input_sizes) + x2 = np.arange( + 1, total_size_2 + 1, dtype=np.float32).reshape(out_backprop_sizes) + strides = [1, stride, stride, 1] + + with self.test_session() as sess: + with self.test_scope(): + t1 = array_ops.placeholder(dtypes.float32, shape=input_sizes) + t2 = array_ops.placeholder(dtypes.float32, shape=out_backprop_sizes) + tensor = gen_nn_ops.conv2d_backprop_filter( + input=t1, + filter_sizes=filter_sizes, + out_backprop=t2, + strides=strides, + padding=padding, + data_format="NHWC") + + value = sess.run(tensor, {t1: x1, t2: x2}) + self.assertArrayNear(expected, np.ravel(value), 1e-5) + + def testConv2D1x1Filter(self): + expected_output = [8056, 8432, 8312, 8704, 8568, 8976] + self._VerifyValues( + input_sizes=[1, 4, 4, 3], + filter_sizes=[1, 1, 3, 2], + out_backprop_sizes=[1, 4, 4, 2], + stride=1, + padding="VALID", + expected=expected_output) + + def testConv2D1x2Filter(self): + expected_output = [120, 141] + self._VerifyValues( + input_sizes=[1, 3, 3, 1], + filter_sizes=[1, 2, 1, 1], + out_backprop_sizes=[1, 3, 2, 1], + stride=1, + padding="VALID", + expected=expected_output) + + def testConv2D2x2FilterDepth1(self): + expected_output = [5, 8, 14, 17] + self._VerifyValues( + input_sizes=[1, 2, 3, 1], + filter_sizes=[2, 2, 1, 1], + out_backprop_sizes=[1, 1, 2, 1], + stride=1, + padding="VALID", + expected=expected_output) + + def testConv2D2x2Filter(self): + expected_output = [ + 17, 22, 27, 22, 29, 36, 27, 36, 45, 32, 43, 54, 37, 50, 63, 42, 57, 72, + 62, 85, 108, 67, 92, 117, 72, 99, 126, 77, 106, 135, 82, 113, 144, 87, + 120, 153 + ] + self._VerifyValues( + input_sizes=[1, 2, 3, 3], + filter_sizes=[2, 2, 3, 3], + out_backprop_sizes=[1, 1, 2, 3], + stride=1, + padding="VALID", + expected=expected_output) + + def testConv2D1x2FilterStride3Width5(self): + expected_output = [9, 12] + self._VerifyValues( + input_sizes=[1, 1, 5, 1], + filter_sizes=[1, 2, 1, 1], + out_backprop_sizes=[1, 1, 2, 1], + stride=3, + padding="VALID", + expected=expected_output) + + def testConv2D1x2FilterStride3Width6(self): + expected_output = [9, 12] + self._VerifyValues( + input_sizes=[1, 1, 6, 1], + filter_sizes=[1, 2, 1, 1], + out_backprop_sizes=[1, 1, 2, 1], + stride=3, + padding="VALID", + expected=expected_output) + + def testConv2D1x2FilterStride3Width7(self): + expected_output = [9, 12] + self._VerifyValues( + input_sizes=[1, 1, 7, 1], + filter_sizes=[1, 2, 1, 1], + out_backprop_sizes=[1, 1, 
2, 1], + stride=3, + padding="VALID", + expected=expected_output) + + def testConv2D1x3Filter(self): + expected_output = [5, 8, 11] + self._VerifyValues( + input_sizes=[1, 1, 4, 1], + filter_sizes=[1, 3, 1, 1], + out_backprop_sizes=[1, 1, 2, 1], + stride=1, + padding="VALID", + expected=expected_output) + + def testConv2D1x3FilterSame(self): + expected_output = [20, 30, 20] + self._VerifyValues( + input_sizes=[1, 1, 4, 1], + filter_sizes=[1, 3, 1, 1], + out_backprop_sizes=[1, 1, 4, 1], + stride=1, + padding="SAME", + expected=expected_output) + + def testConv2D1x3FilterSameOutbackprop2(self): + expected_output = [7, 10, 3] + self._VerifyValues( + input_sizes=[1, 1, 4, 1], + filter_sizes=[1, 3, 1, 1], + out_backprop_sizes=[1, 1, 2, 1], + stride=2, + padding="SAME", + expected=expected_output) + + def testConv2D2x2FilterC1Same(self): + expected_output = [91, 58, 32, 17] + self._VerifyValues( + input_sizes=[1, 2, 3, 1], + filter_sizes=[2, 2, 1, 1], + out_backprop_sizes=[1, 2, 3, 1], + stride=1, + padding="SAME", + expected=expected_output) + + def testConv2D2x2FilterStride2(self): + expected_output = [92, 102, 112] + self._VerifyValues( + input_sizes=[1, 3, 5, 1], + filter_sizes=[1, 3, 1, 1], + out_backprop_sizes=[1, 2, 2, 1], + stride=2, + padding="VALID", + expected=expected_output) + + def testConv2D2x2FilterStride2Same(self): + expected_output = [7, 2, 16, 5] + self._VerifyValues( + input_sizes=[1, 2, 3, 1], + filter_sizes=[2, 2, 1, 1], + out_backprop_sizes=[1, 1, 2, 1], + stride=2, + padding="SAME", + expected=expected_output) + + +class DepthwiseConv2DTest(XLATestCase): + + CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0" + + def ConfigsToTest(self): + input_sizes = [[4, 35, 35, 2], [4, 147, 147, 2], [3, 299, 299, 3], + [5, 183, 183, 1]] + filter_sizes = [[5, 5, 2, 1], [3, 3, 2, 8], [2, 2, 3, 8], [5, 5, 1, 2]] + strides = [1, 3, 2, 2] + # pylint: disable=invalid-name + VALID = "VALID" + SAME = "SAME" + # pylint: enable=invalid-name + paddings = [SAME, VALID, SAME, SAME, SAME] + for i, f, s, p in zip(input_sizes, filter_sizes, strides, paddings): + yield i, f, s, p + + def _VerifyValues(self, input_size, filter_size, stride, padding): + imag = np.random.rand(*input_size).astype(np.float32) + filt = np.random.rand(*filter_size).astype(np.float32) + strides = [1, stride, stride, 1] + + with self.test_session(): + with self.test_scope(): + imag_ph = array_ops.placeholder(dtypes.float32, shape=input_size) + filt_ph = array_ops.placeholder(dtypes.float32, shape=filter_size) + feed_dict = {imag_ph: imag, filt_ph: filt} + xla_out = nn_impl.depthwise_conv2d(imag_ph, filt_ph, strides, + padding).eval(feed_dict=feed_dict) + + with self.test_session(): + with ops.device(self.CPU_DEVICE): + imag_ph = array_ops.placeholder(dtypes.float32, shape=input_size) + filt_ph = array_ops.placeholder(dtypes.float32, shape=filter_size) + feed_dict = {imag_ph: imag, filt_ph: filt} + cpu_out = nn_impl.depthwise_conv2d(imag_ph, filt_ph, strides, + padding).eval(feed_dict=feed_dict) + + self.assertAllClose(xla_out, cpu_out) + + # This is disabled because we need a mechanism to set command-line flags, + # i.e. an implementation of SetCommandLineOption() below. 
+ # + # def _VerifyDummy(self, input_size, filter_size, stride, padding): + # imag = np.random.rand(*input_size).astype(np.float32) + # filt = np.random.rand(*filter_size).astype(np.float32) + # strides = [1, stride, stride, 1] + # + # with self.test_session(): + # with self.test_scope(): + # imag_ph = tf.placeholder(tf.float32, shape=input_size) + # filt_ph = tf.placeholder(tf.float32, shape=filter_size) + # feed_dict = {imag_ph: imag, filt_ph: filt} + # SetCommandLineOption( + # "tf_tla_depthwise_conv2d_custom_func", + # "DummyDepthwiseConv2dKernel") + # xla_out = tf.nn.depthwise_conv2d( + # imag_ph, filt_ph, strides, padding).eval(feed_dict=feed_dict) + # SetCommandLineOption( + # "tf_tla_depthwise_conv2d_custom_func", "") + # + # expected = np.array(range(np.ravel(xla_out).shape[0]), dtype=np.float32) + # self.assertAllClose(np.ravel(xla_out), expected) + + def testBasic(self): + for i, f, s, p in self.ConfigsToTest(): + self._VerifyValues(i, f, s, p) + + # Test disabled until _VerifyDummy() above can be implemented. + # def testCustomFunc(self): + # if self.has_custom_call: + # for i, f, s, p in self.ConfigsToTest(): + # self._VerifyDummy(i, f, s, p) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/compiler/tests/depthwise_conv2d_test_kernel.cc b/tensorflow/compiler/tests/depthwise_conv2d_test_kernel.cc new file mode 100644 index 0000000000..97b71c0228 --- /dev/null +++ b/tensorflow/compiler/tests/depthwise_conv2d_test_kernel.cc @@ -0,0 +1,30 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/platform/types.h" + +using tensorflow::int64; + +// A dummy implementation that fills the output with 0, 1, 2,... +// to test the custom call implementation of DepthwiseConv2dNative op. +// TODO(keveman): Test this after adding a real implementation for the kernel. +extern "C" void DummyDepthwiseConv2dKernel(float* output, void** inputs) { + const int64* output_size = reinterpret_cast<const int64*>(inputs[4]); + const int64 total_size = + output_size[0] * output_size[1] * output_size[2] * output_size[3]; + for (int64 i = 0; i < total_size; ++i) { + *(output + i) = i; + } +} diff --git a/tensorflow/compiler/tests/dynamic_stitch_test.py b/tensorflow/compiler/tests/dynamic_stitch_test.py new file mode 100644 index 0000000000..c109c27abe --- /dev/null +++ b/tensorflow/compiler/tests/dynamic_stitch_test.py @@ -0,0 +1,86 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tf.dynamic_stitch.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import data_flow_ops +from tensorflow.python.platform import googletest + + +class DynamicStitchTest(XLATestCase): + + def _AssertDynamicStitchResultIs(self, indices, data, expected): + with self.test_session() as session: + index_placeholders = [ + array_ops.placeholder(dtypes.as_dtype(arg.dtype)) for arg in indices + ] + data_placeholders = [ + array_ops.placeholder(dtypes.as_dtype(arg.dtype)) for arg in data + ] + with self.test_scope(): + output = data_flow_ops.dynamic_stitch(index_placeholders, + data_placeholders) + + feed_dict = {} + for placeholder, value in zip(index_placeholders, indices): + feed_dict[placeholder] = value + for placeholder, value in zip(data_placeholders, data): + feed_dict[placeholder] = value + result = session.run(output, feed_dict=feed_dict) + self.assertAllClose(expected, result, rtol=1e-3) + + def testSimpleEmpty(self): + idx1 = np.array([0, 2], dtype=np.int32) + idx2 = np.array([[1], [3]], dtype=np.int32) + val1 = np.array([[], []], dtype=np.int32) + val2 = np.array([[[]], [[]]], dtype=np.int32) + self._AssertDynamicStitchResultIs( + [idx1, idx2], [val1, val2], + expected=np.array([[], [], [], []], np.int32)) + + def testSimple1D(self): + val1 = np.array([0, 4, 7], dtype=np.int32) + val2 = np.array([1, 6, 2, 3, 5], dtype=np.int32) + val3 = np.array([0, 40, 70], dtype=np.float32) + val4 = np.array([10, 60, 20, 30, 50], dtype=np.float32) + expected = np.array([0, 10, 20, 30, 40, 50, 60, 70], dtype=np.float32) + self._AssertDynamicStitchResultIs( + [val1, val2], [val3, val4], expected=expected) + + def testSimple2D(self): + val1 = np.array([0, 4, 7], dtype=np.int32) + val2 = np.array([1, 6], dtype=np.int32) + val3 = np.array([2, 3, 5], dtype=np.int32) + val4 = np.array([[0, 1], [40, 41], [70, 71]], dtype=np.float32) + val5 = np.array([[10, 11], [60, 61]], dtype=np.float32) + val6 = np.array([[20, 21], [30, 31], [50, 51]], dtype=np.float32) + expected = np.array( + [[0, 1], [10, 11], [20, 21], [30, 31], [40, 41], [50, 51], [60, 61], + [70, 71]], + dtype=np.float32) + self._AssertDynamicStitchResultIs( + [val1, val2, val3], [val4, val5, val6], expected=expected) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/compiler/tests/function_test.py b/tensorflow/compiler/tests/function_test.py new file mode 100644 index 0000000000..40cc7a5d60 --- /dev/null +++ b/tensorflow/compiler/tests/function_test.py @@ -0,0 +1,130 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test cases for Tensorflow functions.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import function +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import googletest + + +class FunctionTest(XLATestCase): + + def testFunction(self): + """Executes a simple TensorFlow function.""" + + def APlus2B(a, b): + return a + b * 2 + + aval = np.array([4, 3, 2, 1]).reshape([2, 2]).astype(np.float32) + bval = np.array([5, 6, 7, 8]).reshape([2, 2]).astype(np.float32) + expected = APlus2B(aval, bval) + + with self.test_session() as sess: + + @function.Defun(dtypes.float32, dtypes.float32) + def Foo(a, b): + return APlus2B(a, b) + + a = constant_op.constant(aval, name="a") + b = constant_op.constant(bval, name="b") + with self.test_scope(): + call_f = Foo(a, b) + result = sess.run(call_f) + self.assertAllClose(result, expected, rtol=1e-3) + + def testNestedFunctions(self): + """Executes two nested TensorFlow functions.""" + + def TimesTwo(x): + return x * 2 + + def APlus2B(a, b): + return a + TimesTwo(b) + + aval = np.array([4, 3, 2, 1]).reshape([2, 2]).astype(np.float32) + bval = np.array([4, 3, 2, 1]).reshape([2, 2]).astype(np.float32) + expected = APlus2B(aval, bval) + + with self.test_session() as sess: + + @function.Defun(dtypes.float32, dtypes.float32) + def Foo(a, b): + return APlus2B(a, b) + + a = constant_op.constant(aval, name="a") + b = constant_op.constant(bval, name="b") + with self.test_scope(): + call_g = Foo(a, b) + result = sess.run(call_g) + self.assertAllClose(result, expected, rtol=1e-3) + + def testFunctionMultipleRetvals(self): + """Executes a function with multiple return values.""" + + # This function will run on the XLA device + def Func(a, b): + return a + b, a - b + + aval = np.array([4, 3, 2, 1]).reshape([2, 2]).astype(np.float32) + bval = np.array([5, 6, 7, 8]).reshape([2, 2]).astype(np.float32) + expected = Func(aval, bval) + + with self.test_session() as sess: + + @function.Defun(dtypes.float32, dtypes.float32) + def Foo(a, b): + return Func(a, b) + + a = constant_op.constant(aval, name="a") + b = constant_op.constant(bval, name="b") + with self.test_scope(): + call_f = Foo(a, b) + result = sess.run(call_f) + self.assertAllClose(result, expected, rtol=1e-3) + + def testFunctionsNoInline(self): + + @function.Defun(dtypes.float32, noinline=True) + def TimesTwo(x): + return x * 2 + + @function.Defun(dtypes.float32, dtypes.float32) + def APlus2B(a, b): + return a + TimesTwo(b) + + aval = np.array([4, 3, 2, 1]).reshape([2, 2]).astype(np.float32) + bval = np.array([4, 3, 2, 1]).reshape([2, 2]).astype(np.float32) + expected = aval + bval * 2 + + with self.test_session() as sess: + with self.test_scope(): + a = array_ops.placeholder(dtypes.float32, name="a") + b = array_ops.placeholder(dtypes.float32, name="b") + call = APlus2B(a, b) + result = sess.run(call, {a: aval, b: bval}) + self.assertAllClose(result, expected, rtol=1e-3) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py new file mode 
100644 index 0000000000..8a568d6d58 --- /dev/null +++ b/tensorflow/compiler/tests/jit_test.py @@ -0,0 +1,459 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for JIT compilation on the CPU and GPU devices.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.compiler import jit +from tensorflow.core.framework import function_pb2 +from tensorflow.core.framework import node_def_pb2 +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.client import session as session_lib +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import function +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.platform import test + +jit_scope = jit.experimental_jit_scope + + +def CompiledKernel(fn, *inputs, **kwargs): + """Execute 'fn' as a compiled XLA kernel, with 'inputs'.""" + name = kwargs.pop("name", None) + noinline = kwargs.pop("noinline", None) + + @function.Defun(func_name=name, noinline=noinline, compiled=True) + def Compiled(*args): + return fn(*args) + + return Compiled(*inputs) + + +def RunMetadataLabels(run_metadata): + """Returns all labels in run_metadata.""" + labels = [] + for dev_stats in run_metadata.step_stats.dev_stats: + for node_stats in dev_stats.node_stats: + labels.append(node_stats.timeline_label) + return labels + + +def InLabels(labels, substr): + """Returns true iff one of the labels contains substr.""" + return any([substr in x for x in labels]) + + +def MetadataHasXlaLaunch(run_metadata): + """Returns true if there is a _XlaLaunch kernel in run_metadata's timeline.""" + + # TODO(phawkins): find a less hacky way to test whether a kernel ran. + return InLabels(RunMetadataLabels(run_metadata), "_XlaLaunch") + + +class JitLaunchTest(test.TestCase): + + # Evaluates 'fn' on 'args' both directly and as a compiled XLA kernel. + # Verifies that the outputs match and that XLA was invoked. 'fn' must take + # the same number of tensors as arguments that are in 'args', and must return + # a tuple of output tensors. + # If 'require_kernel_launch' is True, then we verify that a _XlaLaunch node + # actually ran. However, it is sometimes possible for _XlaLaunch ops to be + # constant-folded away, so the check is optional. 
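# An illustrative sketch only (the helper and test values below are simply
# borrowed from testInt32Input further down in this file) of how this harness
# is typically invoked:
#
#   def AddToSelf(x):
#     return math_ops.add(x, x)
#
#   self._compare(AddToSelf, [np.array([7, 1, 3], dtype=np.int32)])
#
# _compare() builds the graph twice, once through CompiledKernel() (which wraps
# 'fn' in a Defun with compiled=True) and once directly, and then checks both
# results with assertAllClose().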
+ def _compare(self, fn, args, require_kernel_launch=True, noinline=None): + with session_lib.Session() as sess: + placeholders = [] + feeds = {} + for arg in args: + placeholder = array_ops.placeholder( + dtypes.as_dtype(arg.dtype), list(arg.shape)) + placeholders.append(placeholder) + feeds[placeholder] = arg + + compiled_op = CompiledKernel(fn, *placeholders, noinline=noinline) + direct_op = fn(*placeholders) + + run_metadata = config_pb2.RunMetadata() + compiled = sess.run(compiled_op, + feeds, + run_metadata=run_metadata, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + print("Compiled Result {}".format(compiled)) + + if require_kernel_launch: + self.assert_(MetadataHasXlaLaunch(run_metadata)) + + direct = sess.run(direct_op, feeds) + print("Direct Result {}".format(direct)) + + if (isinstance(compiled, (tuple, list)) and + (isinstance(direct, (tuple, list)))): + for (x, y) in zip(compiled, direct): + self.assertAllClose(x, y, rtol=1e-1) + else: + self.assertAllClose(compiled, direct) + + def testNoOutputs(self): + with session_lib.Session() as sess: + # Build a function with a single Const node, whose output is ignored. + fdef = function_pb2.FunctionDef() + fdef.signature.name = "KernelWithNoOutputs" + node = node_def_pb2.NodeDef() + node.op = "Const" + node.name = "ignored" + node.attr["dtype"].type = dtypes.int32.as_datatype_enum + tensor = tensor_util.make_tensor_proto([0], dtype=dtypes.int32, shape=[]) + node.attr["value"].tensor.CopyFrom(tensor) + fdef.node_def.extend([node]) + + # Check that calling the result as a compiled kernel doesn't crash. + @function.Defun(compiled=True) + def KernelWithNoOutputs(): + return constant_op.constant(100) + + # Hack to override the definition. By accessing .definition, we + # force the _DefinedFunction to be initialized internally. Then, we + # replace its internal FunctionDef proto. We do this hack here + # because one typically can't construct the KernelWithNoOutputs + # function via the Defun decorator directly. + _ = KernelWithNoOutputs.definition + foo = KernelWithNoOutputs + foo._definition = fdef + call = KernelWithNoOutputs() + sess.run(call, {}) + + def testAliasing(self): + """Regression test for compiled functions that return an aliased buffer. + + XLA returns aliased buffers if outputs are identical. Tests that + we handle that case. + """ + + def AddOnceReturnTwice(x): + y = math_ops.add(x, x) + return y, y + + # Exercises compiling a function (say, Foo) which calls another + # function (say, Bar) which is not inlined. When the compiler compiles + # Foo, it needs to symbolically execute Bar correctly regardless of whether + # Bar is inlined or not. + # + # Tests compiled=True and noinline=True. + self._compare( + AddOnceReturnTwice, [np.array( + [[[0.5, -1.0]]], dtype=np.float32)], + noinline=True) + # Tests compiled=True and noinline=False.
+ self._compare( + AddOnceReturnTwice, [np.array( + [[[0.5, -1.0]]], dtype=np.float32)], + noinline=False) + + def testOneConstOutput(self): + """Test consisting of a single constant return value.""" + + def OneConstOutput(): + return constant_op.constant([-3, 44, 99]) + + self._compare(OneConstOutput, [], require_kernel_launch=False) + + def testConstZeroElementOutput(self): + """Test consisting of a constant zero element return value.""" + + def ConstZeroElementOutput(): + return array_ops.fill([7, 0], 3.0) + + self._compare(ConstZeroElementOutput, [], require_kernel_launch=False) + + def testSomeConstOutputs(self): + """Test kernels that return a mixture of const and non-const outputs.""" + + def SomeConstOutputs(x): + return constant_op.constant( + [-2, 7]), array_ops.identity(x), constant_op.constant(3.5) + + self._compare( + SomeConstOutputs, [np.array( + [[1, 2, 3], [4, 5, 6]], dtype=np.float32)]) + + def testInt32Input(self): + """Test an int32-typed input. + + On a GPU, int32 tensors will be placed in host memory. + """ + + def AddToSelf(x): + return math_ops.add(x, x) + + self._compare(AddToSelf, [np.array([7, 1, 3], dtype=np.int32)]) + + def testMandatoryConstantInput(self): + """Tests an operator that has a mandatory-constant shape input.""" + + def FillWithFloat(x): + return array_ops.fill(x, 9.5) + + self._compare(FillWithFloat, [np.array([3, 2], dtype=np.int32)]) + + def testMnistForwardFunc(self): + """Compute inference function from MNIST beginners tutorial.""" + batch_size = 16 + image_size = 28 * 28 + num_classes = 10 + + # Define a TensorFlow function to compute the forward pass. + def MnistForward(w, b, x): + return nn_ops.softmax(math_ops.matmul(x, w) + b) + + w = np.random.random_sample((image_size, num_classes)).astype(np.float32) + b = np.random.random_sample((num_classes)).astype(np.float32) + x = np.random.random_sample((batch_size, image_size)).astype(np.float32) + self._compare(MnistForward, [w, b, x]) + + def testExplicitMarking(self): + """Test explicit marking of operators to compile.""" + batch_size = 16 + image_size = 28 * 28 + num_classes = 10 + + with ops.Graph().as_default(): + x = array_ops.placeholder(dtypes.float32) + w = array_ops.placeholder(dtypes.float32) + b = array_ops.placeholder(dtypes.float32) + with jit_scope(): + y1 = math_ops.matmul(x, w) + y2 = math_ops.add(y1, b) + with jit_scope(): + y = math_ops.square(y2) + + dw = np.random.random_sample((image_size, num_classes)).astype(np.float32) + db = np.random.random_sample((num_classes)).astype(np.float32) + dx = np.random.random_sample((batch_size, image_size)).astype(np.float32) + with session_lib.Session() as sess: + run_metadata = config_pb2.RunMetadata() + output = sess.run(y, {x: dx, + w: dw, + b: db}, + run_metadata=run_metadata, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + + # TODO(phawkins): really we would like to test that there were exactly + # two kernel launches. However, we have no reliable way to determine + # that. 
+ self.assert_(MetadataHasXlaLaunch(run_metadata)) + + expected = np.square(np.dot(dx, dw) + db) + self.assertAllClose(expected, output, rtol=1e-1) + + +class XlaCompilationTest(test.TestCase): + """Tests for auto-compilation on CPU/GPU devices.""" + + def testReshape(self): + """Tests an operator with compile-time constant and non-constant inputs.""" + + with self.test_session() as sess: + x = array_ops.placeholder(dtypes.float32) + y = array_ops.placeholder(dtypes.int32) + with jit_scope(): + # Reshape's first argument is non-constant in the JIT, but its second + # (shape) argument will be treated as a compile-time constant for + # each JIT compilation. + # We do not use a tf.const() argument since we want to ensure the + # shape is still a run-time argument to the JIT, and not + # statically known as part of the JIT compilation's input graph. + z = array_ops.reshape(x, y) + run_metadata = config_pb2.RunMetadata() + out = sess.run(z, + {x: np.array([1, 2, 3, 4, 5, 6], np.float32), + y: [-1, 3]}, + run_metadata=run_metadata, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + self.assert_(MetadataHasXlaLaunch(run_metadata)) + self.assertAllClose(np.array([[1, 2, 3], [4, 5, 6]], np.float32), out) + + def testIgnoredArguments(self): + """Tests that JIT computations can ignore formal parameters.""" + + with self.test_session() as sess: + x = array_ops.placeholder(dtypes.int32) + y = array_ops.placeholder(dtypes.int32) + with jit_scope(): + z = math_ops.add(x, x) + w = math_ops.add(y, y) + # Pulls 'w' into the same compilation via control dependencies. + with ops.control_dependencies([w]): + n = control_flow_ops.no_op() + with ops.control_dependencies([n]): + t = math_ops.add(z, z) + + run_metadata = config_pb2.RunMetadata() + out = sess.run(t, {x: np.int32(7), + y: np.int32(404)}, + run_metadata=run_metadata, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + self.assert_(MetadataHasXlaLaunch(run_metadata)) + self.assertAllClose(28, out) + + def testLoops(self): + """Tests that compilation accepts computations containing loops.""" + + with self.test_session() as session: + x = array_ops.placeholder(dtypes.float32) + with jit_scope(): + c = lambda i, _: math_ops.less(i, 5) + b = lambda i, x: (i + 1, x * 2.0 + 1.0) + _, y = control_flow_ops.while_loop(c, b, (constant_op.constant(0), x)) + + run_metadata = config_pb2.RunMetadata() + result = session.run(y, {x: np.float32(2)}, + run_metadata=run_metadata, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + self.assert_(MetadataHasXlaLaunch(run_metadata)) + self.assertAllClose(result, np.float32(95), rtol=1e-1) + + def testCond(self): + """Tests that compilation handles switch operators.""" + + with self.test_session() as session: + x = array_ops.placeholder(dtypes.float32) + y = array_ops.placeholder(dtypes.float32) + c = array_ops.placeholder(dtypes.bool) + with jit_scope(): + z = x + 1.0 + w = control_flow_ops.cond(c, lambda: z, lambda: y) + t = math_ops.add(z, w) + + # If JIT compilation chooses to cluster z and t, then execution will + # deadlock. 
+ + run_metadata = config_pb2.RunMetadata() + result = session.run(t, {x: np.float32(2), + y: np.float32(4), + c: True}, + run_metadata=run_metadata, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + self.assert_(MetadataHasXlaLaunch(run_metadata)) + self.assertAllClose(result, np.float32(6), rtol=1e-1) + + def testNestedFunction(self): + g = ops.Graph() + with g.as_default(): + + @function.Defun(compiled=True) + def Bar(x, y): + return x + 2 * y + + @function.Defun(compiled=True) + def Foo(x): + return Bar(x * x, x * x * x) + + @function.Defun() + def Entry(x): + return Foo(x) + + inp = array_ops.placeholder(dtypes.float32) + out = Entry(inp) + + with self.test_session(graph=g, use_gpu=True) as sess: + run_metadata = config_pb2.RunMetadata() + val = sess.run(out, + feed_dict={inp: [2., 10.]}, + run_metadata=run_metadata, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + self.assertAllClose(val, [20., 2100.]) + + def testLoopDeadlock(self): + """Regression test for bug that caused deadlocks in graphs with loops.""" + + with self.test_session() as session: + x = array_ops.placeholder(dtypes.float32) + with jit_scope(): + y = x + 1.0 + c = lambda i, _x, _y: math_ops.less(i, 5) + b = lambda i, x, _y: (i + 1, x * 2.0 + 1.0, x - 3.0) + _, _, w = control_flow_ops.while_loop(c, b, + (constant_op.constant(0), y, x)) + u = w + y + result = session.run(u, {x: np.float32(2)}) + self.assertAllClose(result, np.float32(63), rtol=1e-1) + + def testGradient(self): + """Tests that the backprop function is properly compiled.""" + + def _Run(compiled): + + @function.Defun(compiled=compiled) + def Forward(x): + return math_ops.log(x) + + g = ops.Graph() + with g.as_default(): + x = array_ops.placeholder(dtypes.float32) + y = Forward(x) + dx, = gradients_impl.gradients(y, [x], 1.0) + + cfg = config_pb2.ConfigProto(graph_options=config_pb2.GraphOptions( + optimizer_options=config_pb2.OptimizerOptions( + opt_level=config_pb2.OptimizerOptions.L1, + do_function_inlining=True))) + with session_lib.Session(graph=g, config=cfg) as sess: + run_metadata = config_pb2.RunMetadata() + dx_val = sess.run(dx, + feed_dict={x: 100.}, + run_metadata=run_metadata, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + self.assertAllClose(dx_val, 0.01) + return RunMetadataLabels(run_metadata) + + # SymGrad[f=log(x)](x, dy) = 1/x * dy + # + # Note: we don't need to compute log(x) for dx due to graph pruning. + + # Do not compile the backprop. We should see one Reciprocal and one Mul. + labels = _Run(compiled=False) + self.assertFalse(InLabels(labels, "Log")) + self.assertTrue(InLabels(labels, "Reciprocal")) + self.assertTrue(InLabels(labels, "Mul")) + self.assertFalse(InLabels(labels, "_XlaLaunch")) + + # Compile the backprop. One _XlaLaunch. + labels = _Run(compiled=True) + self.assertFalse(InLabels(labels, "Log")) + self.assertFalse(InLabels(labels, "Reciprocal")) + self.assertFalse(InLabels(labels, "Mul")) + self.assertTrue(InLabels(labels, "_XlaLaunch")) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/lrn_ops_test.py b/tensorflow/compiler/tests/lrn_ops_test.py new file mode 100644 index 0000000000..5d8d89224d --- /dev/null +++ b/tensorflow/compiler/tests/lrn_ops_test.py @@ -0,0 +1,129 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Local Response Normalization ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_nn_ops +from tensorflow.python.ops import nn +from tensorflow.python.platform import googletest + +CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0" + + +# Local response normalization tests. The forward tests are copied from +# tensorflow/python/kernel_tests/lrn_op_test.py +class LRNTest(XLATestCase): + + def _LRN(self, input_image, lrn_depth_radius=5, bias=1.0, alpha=1.0, + beta=0.5): + """Compute expected result.""" + output = copy.deepcopy(input_image) + batch_size = input_image.shape[0] + rows = input_image.shape[1] + cols = input_image.shape[2] + depth = input_image.shape[3] + for b in range(batch_size): + for r in range(rows): + for c in range(cols): + for d in range(depth): + begin = max(0, d - lrn_depth_radius) + end = min(depth, d + lrn_depth_radius + 1) + patch = input_image[b, r, c, begin:end] + output[b, r, c, d] /= ( + np.power(bias + alpha * np.sum(patch * patch), beta)) + return output + + def _RunAndVerify(self, dtype): + with self.test_session(): + # random shape + shape = np.random.randint(1, 16, size=4) + # Make depth at least 2 to make it meaningful + shape[3] += 1 + p = array_ops.placeholder(dtype, shape=shape) + # random depth_radius, bias, alpha, beta + lrn_depth_radius = np.random.randint(1, shape[3]) + bias = 1.0 + np.random.rand() + alpha = 2.0 * np.random.rand() + beta = 2.0 * np.random.rand() + with self.test_scope(): + lrn_t = nn.local_response_normalization( + p, + name="lrn", + depth_radius=lrn_depth_radius, + bias=bias, + alpha=alpha, + beta=beta) + params = {p: np.random.rand(*shape).astype("f")} + result = lrn_t.eval(feed_dict=params) + expected = self._LRN( + params[p], + lrn_depth_radius=lrn_depth_radius, + bias=bias, + alpha=alpha, + beta=beta) + err = np.amax(np.abs(result - expected)) + print("LRN error for bias ", bias, "alpha ", alpha, " beta ", beta, " is ", + err) + if dtype == dtypes.float32: + self.assertTrue(err < 1e-4) + else: + self.assertTrue(err < 1e-2) + self.assertShapeEqual(expected, lrn_t) + + def testCompute(self): + for _ in range(2): + self._RunAndVerify(dtypes.float32) + + def testLrnGrad(self): + # Test for LRNGrad that compares against the CPU implementation. 
+ shape = [1, 2, 3, 4] + total_size = np.prod(shape) + in_image_vals = np.arange(1, total_size + 1, dtype=np.float32) + out_image_vals = np.arange(1, total_size + 1, dtype=np.float32) + out_grads_vals = np.arange(1, total_size + 1, dtype=np.float32) + depth_radius = np.random.randint(1, shape[3]) + bias = 1.0 + np.random.rand() + alpha = 1.0 * np.random.rand() + beta = 1.0 * np.random.rand() + + with self.test_session(): + in_image = constant_op.constant(in_image_vals, shape=shape) + out_image = constant_op.constant(out_image_vals, shape=shape) + out_grads = constant_op.constant(out_grads_vals, shape=shape) + with ops.device(CPU_DEVICE): + expected = gen_nn_ops._lrn_grad(out_grads, in_image, out_image, + depth_radius, bias, alpha, beta) + with self.test_scope(): + actual = gen_nn_ops._lrn_grad(out_grads, in_image, out_image, + depth_radius, bias, alpha, beta) + expected_val = expected.eval() + actual_val = actual.eval() + self.assertAllClose(actual_val, expected_val, rtol=1e-3) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/compiler/tests/lstm.py b/tensorflow/compiler/tests/lstm.py new file mode 100644 index 0000000000..18166f51bf --- /dev/null +++ b/tensorflow/compiler/tests/lstm.py @@ -0,0 +1,158 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A simple LSTM layer with benchmarks. + +This sets up a simple LSTM (Long Short Term Memory) layer, unrolled to a fixed +length sequence. The only deviation from standard LSTM cells is that +activations are clipped, inspired by the GNMT machine translation model. +The GNMT paper has more details: https://arxiv.org/abs/1609.08144 +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variables + + +def Clip(x): + """Clips x to the range [-1., 1.].""" + return math_ops.maximum(math_ops.minimum(x, 1.), -1.) + + +def LSTMCellWeightsShape(num_inputs, num_nodes): + """Returns the shape of the weights for a single LSTM cell.""" + # Dimension 0 accounts for combining x with the previous m state. + # Dimension 1 accounts for the in value and the (in, forget, out) gates. + return [num_inputs + num_nodes, 4 * num_nodes] + + +def LSTMCell(weights, m_prev, c_prev, x, pad): + """Unrolls a single LSTM cell with clipped activations forward by one step. + + Args: + weights: Weight matrix with shape LSTMCellWeightsShape. + m_prev: Previous m states with shape [batch_size, num_nodes]. + c_prev: Previous c states with shape [batch_size, num_nodes]. + x: Input with shape [batch_size, num_inputs]. + pad: Padding with shape [batch_size, 1]. 
Each padding value is either + 0 or 1, where 1 indicates padding; i.e. the input is shorter than the + sequence length, and the (m, c) states should simply be passed through + from the previous states. + Returns: + The next (m, c) states, each with shape [batch_size, num_nodes]. + """ + # Apply weights to the input and previous hidden state. + # The matmul here is the "big" operation. + xm = array_ops.concat_v2([x, m_prev], 1) + xmw = math_ops.matmul(xm, weights) + + # Element-wise ops for the standard LSTM cell, with clipped activations. + # XLA can fuse these operations into a single loop. + in_value, in_gate, forget_gate, out_gate = array_ops.split( + value=xmw, num_or_size_splits=4, axis=1) + in_value = math_ops.tanh(in_value) + in_gate = math_ops.sigmoid(in_gate) + forget_gate = math_ops.sigmoid(forget_gate) + out_gate = math_ops.sigmoid(out_gate) + c_next = Clip(Clip(forget_gate * c_prev) + Clip(in_gate * in_value)) + m_next = Clip(out_gate * c_next) + + # Account for padding. + c_next = c_prev * pad + c_next * (1.0 - pad) + m_next = m_prev * pad + m_next * (1.0 - pad) + + return m_next, c_next + + +def LSTMLayer(cell_name, weights, m, c, x_seq, pad_seq): + """Unrolls a layer of LSTM cells forward by the sequence length. + + The sequence length is determined by the length of x_seq and pad_seq, which + must be the same. + + Args: + cell_name: Base name of each cell. + weights: Weight matrix with shape LSTMCellWeightsShape. + m: Initial m states with shape [batch_size, num_nodes]. + c: Initial c states with shape [batch_size, num_nodes]. + x_seq: List of inputs, each with shape [batch_size, num_inputs]. + The length of the list is the sequence length. + pad_seq: List of paddings, each with shape [batch_size, 1]. + The length of the list is the sequence length. + Each padding value is either 0 or 1, where 1 indicates padding; + i.e. the input is shorter than the sequence length. + Returns: + List of per-sequence-step outputs, each with shape [batch_size, num_nodes]. + Raises: + ValueError: If len(x_seq) != len(pad_seq). + """ + if len(x_seq) != len(pad_seq): + raise ValueError('length of x_seq(%d) != pad_seq(%d)' % + (len(x_seq), len(pad_seq))) + out_seq = [] + for seq in range(len(x_seq)): + with ops.name_scope('%s_%d' % (cell_name, seq)): + m, c = LSTMCell(weights, m, c, x_seq[seq], pad_seq[seq]) + out_seq.append(array_ops.identity(m, name='out')) + return out_seq + + +def RandomVar(shape, name=None): + """Returns a variable of the given shape initialized to random values.""" + return variables.Variable( + random_ops.random_uniform(shape), dtype=dtypes.float32, name=name) + + +def RandomInputs(batch_size, seq_length, num_inputs): + """Returns randomly initialized (x_seq, pad_seq) sequences.""" + x_seq = [] + pad_seq = [] + with ops.name_scope('inputs'): + for seq in range(seq_length): + x_seq.append(RandomVar([batch_size, num_inputs], name='x_seq_%d' % seq)) + # Real padding values are always a sequence of 0 followed by a + # sequence of 1, but random values are fine for benchmarking. + pad_seq.append(RandomVar([batch_size, 1], name='pad_seq_%d' % seq)) + return x_seq, pad_seq + + +def BuildLSTMLayer(batch_size, seq_length, num_inputs, num_nodes): + """Builds a single LSTM layer with random weights and inputs. + + Args: + batch_size: Inputs are fed in batches of this size. + seq_length: The sequence length to unroll the LSTM layer. + num_inputs: Dimension of inputs that are fed into each LSTM cell. + num_nodes: The number of nodes in each LSTM cell. 
+ + Returns: + (out_seq, weights) pair. The out_seq is a list of per-sequence-step + outputs, each with shape [batch_size, num_nodes]. The weights are a list of + weight variables that may be trained. + """ + weights = RandomVar( + LSTMCellWeightsShape(num_inputs, num_nodes), name='weights') + m = array_ops.zeros([batch_size, num_nodes], name='init_m') + c = array_ops.zeros([batch_size, num_nodes], name='init_c') + x_seq, pad_seq = RandomInputs(batch_size, seq_length, num_inputs) + + out_seq = LSTMLayer('lstm', weights, m, c, x_seq, pad_seq) + return out_seq, [weights] diff --git a/tensorflow/compiler/tests/lstm_layer_inference.config.pbtxt b/tensorflow/compiler/tests/lstm_layer_inference.config.pbtxt new file mode 100644 index 0000000000..c46e65f71a --- /dev/null +++ b/tensorflow/compiler/tests/lstm_layer_inference.config.pbtxt @@ -0,0 +1,20 @@ +# Text form of tensorflow.tfcompile.Config proto. +feed{ id{node_name:"inputs/x_seq_0/read"} shape{dim{size:128}dim{size:1024}} } +feed{ id{node_name:"inputs/x_seq_1/read"} shape{dim{size:128}dim{size:1024}} } +feed{ id{node_name:"inputs/x_seq_2/read"} shape{dim{size:128}dim{size:1024}} } +feed{ id{node_name:"inputs/x_seq_3/read"} shape{dim{size:128}dim{size:1024}} } +feed{ id{node_name:"inputs/x_seq_4/read"} shape{dim{size:128}dim{size:1024}} } +feed{ id{node_name:"inputs/pad_seq_0/read"} shape{dim{size:128}dim{size:1}} } +feed{ id{node_name:"inputs/pad_seq_1/read"} shape{dim{size:128}dim{size:1}} } +feed{ id{node_name:"inputs/pad_seq_2/read"} shape{dim{size:128}dim{size:1}} } +feed{ id{node_name:"inputs/pad_seq_3/read"} shape{dim{size:128}dim{size:1}} } +feed{ id{node_name:"inputs/pad_seq_4/read"} shape{dim{size:128}dim{size:1}} } +feed{ id{node_name:"weights/read"} shape{dim{size:2048}dim{size:4096}} } +feed{ id{node_name:"init_c"} shape{dim{size:128}dim{size:1024}} } +feed{ id{node_name:"init_m"} shape{dim{size:128}dim{size:1024}} } + +fetch{ id{node_name:"lstm_0/out"} } +fetch{ id{node_name:"lstm_1/out"} } +fetch{ id{node_name:"lstm_2/out"} } +fetch{ id{node_name:"lstm_3/out"} } +fetch{ id{node_name:"lstm_4/out"} } diff --git a/tensorflow/compiler/tests/lstm_layer_inference.pbtxt b/tensorflow/compiler/tests/lstm_layer_inference.pbtxt new file mode 100644 index 0000000000..649f4a8ff1 --- /dev/null +++ b/tensorflow/compiler/tests/lstm_layer_inference.pbtxt @@ -0,0 +1,5828 @@ +# Generated by running lstm_test, setting --dump_graph_dir. 
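# For reference, a rough sketch of how a dump like this can be regenerated;
# the bazel invocation and output directory below are assumptions, only the
# --dump_graph_dir flag comes from the comment above:
#   bazel run //tensorflow/compiler/tests:lstm_test -- --dump_graph_dir=/tmp/lstm_dump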
+ +node { + name: "random_uniform/shape" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\000\010\000\000\000\020\000\000" + } + } + } +} +node { + name: "random_uniform/min" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "random_uniform/max" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "random_uniform/RandomUniform" + op: "RandomUniform" + input: "random_uniform/shape" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } +} +node { + name: "random_uniform/sub" + op: "Sub" + input: "random_uniform/max" + input: "random_uniform/min" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "random_uniform/mul" + op: "Mul" + input: "random_uniform/RandomUniform" + input: "random_uniform/sub" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "random_uniform" + op: "Add" + input: "random_uniform/mul" + input: "random_uniform/min" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "weights" + op: "Variable" + device: "/device:GPU:*" + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 2048 + } + dim { + size: 4096 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "weights/Assign" + op: "Assign" + input: "weights" + input: "random_uniform" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@weights" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "weights/read" + op: "Identity" + input: "weights" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@weights" + } + } + } +} +node { + name: "init_m" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 128 + } + dim { + size: 1024 + } + } + float_val: 0.0 + } + } + } +} +node { + name: "init_c" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 128 + } + dim { + size: 1024 + } + } + float_val: 0.0 + } + } + } +} +node { + name: "inputs/random_uniform/shape" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: 
"\200\000\000\000\000\004\000\000" + } + } + } +} +node { + name: "inputs/random_uniform/min" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "inputs/random_uniform/max" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "inputs/random_uniform/RandomUniform" + op: "RandomUniform" + input: "inputs/random_uniform/shape" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } +} +node { + name: "inputs/random_uniform/sub" + op: "Sub" + input: "inputs/random_uniform/max" + input: "inputs/random_uniform/min" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "inputs/random_uniform/mul" + op: "Mul" + input: "inputs/random_uniform/RandomUniform" + input: "inputs/random_uniform/sub" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "inputs/random_uniform" + op: "Add" + input: "inputs/random_uniform/mul" + input: "inputs/random_uniform/min" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "inputs/x_seq_0" + op: "Variable" + device: "/device:GPU:*" + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 128 + } + dim { + size: 1024 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "inputs/x_seq_0/Assign" + op: "Assign" + input: "inputs/x_seq_0" + input: "inputs/random_uniform" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@inputs/x_seq_0" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "inputs/x_seq_0/read" + op: "Identity" + input: "inputs/x_seq_0" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@inputs/x_seq_0" + } + } + } +} +node { + name: "inputs/random_uniform_1/shape" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\200\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "inputs/random_uniform_1/min" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "inputs/random_uniform_1/max" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "inputs/random_uniform_1/RandomUniform" + op: "RandomUniform" + input: "inputs/random_uniform_1/shape" + device: 
"/device:GPU:*" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } +} +node { + name: "inputs/random_uniform_1/sub" + op: "Sub" + input: "inputs/random_uniform_1/max" + input: "inputs/random_uniform_1/min" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "inputs/random_uniform_1/mul" + op: "Mul" + input: "inputs/random_uniform_1/RandomUniform" + input: "inputs/random_uniform_1/sub" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "inputs/random_uniform_1" + op: "Add" + input: "inputs/random_uniform_1/mul" + input: "inputs/random_uniform_1/min" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "inputs/pad_seq_0" + op: "Variable" + device: "/device:GPU:*" + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 128 + } + dim { + size: 1 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "inputs/pad_seq_0/Assign" + op: "Assign" + input: "inputs/pad_seq_0" + input: "inputs/random_uniform_1" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@inputs/pad_seq_0" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "inputs/pad_seq_0/read" + op: "Identity" + input: "inputs/pad_seq_0" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@inputs/pad_seq_0" + } + } + } +} +node { + name: "inputs/random_uniform_2/shape" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\200\000\000\000\000\004\000\000" + } + } + } +} +node { + name: "inputs/random_uniform_2/min" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "inputs/random_uniform_2/max" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "inputs/random_uniform_2/RandomUniform" + op: "RandomUniform" + input: "inputs/random_uniform_2/shape" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } +} +node { + name: "inputs/random_uniform_2/sub" + op: "Sub" + input: "inputs/random_uniform_2/max" + input: "inputs/random_uniform_2/min" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "inputs/random_uniform_2/mul" + op: "Mul" + input: "inputs/random_uniform_2/RandomUniform" + input: "inputs/random_uniform_2/sub" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } 
+ } +} +node { + name: "inputs/random_uniform_2" + op: "Add" + input: "inputs/random_uniform_2/mul" + input: "inputs/random_uniform_2/min" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "inputs/x_seq_1" + op: "Variable" + device: "/device:GPU:*" + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 128 + } + dim { + size: 1024 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "inputs/x_seq_1/Assign" + op: "Assign" + input: "inputs/x_seq_1" + input: "inputs/random_uniform_2" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@inputs/x_seq_1" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "inputs/x_seq_1/read" + op: "Identity" + input: "inputs/x_seq_1" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@inputs/x_seq_1" + } + } + } +} +node { + name: "inputs/random_uniform_3/shape" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\200\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "inputs/random_uniform_3/min" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "inputs/random_uniform_3/max" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "inputs/random_uniform_3/RandomUniform" + op: "RandomUniform" + input: "inputs/random_uniform_3/shape" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } +} +node { + name: "inputs/random_uniform_3/sub" + op: "Sub" + input: "inputs/random_uniform_3/max" + input: "inputs/random_uniform_3/min" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "inputs/random_uniform_3/mul" + op: "Mul" + input: "inputs/random_uniform_3/RandomUniform" + input: "inputs/random_uniform_3/sub" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "inputs/random_uniform_3" + op: "Add" + input: "inputs/random_uniform_3/mul" + input: "inputs/random_uniform_3/min" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "inputs/pad_seq_1" + op: "Variable" + device: "/device:GPU:*" + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 128 + } + dim { + size: 1 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "inputs/pad_seq_1/Assign" + op: "Assign" + input: "inputs/pad_seq_1" + input: 
"inputs/random_uniform_3" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@inputs/pad_seq_1" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "inputs/pad_seq_1/read" + op: "Identity" + input: "inputs/pad_seq_1" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@inputs/pad_seq_1" + } + } + } +} +node { + name: "inputs/random_uniform_4/shape" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\200\000\000\000\000\004\000\000" + } + } + } +} +node { + name: "inputs/random_uniform_4/min" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "inputs/random_uniform_4/max" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "inputs/random_uniform_4/RandomUniform" + op: "RandomUniform" + input: "inputs/random_uniform_4/shape" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } +} +node { + name: "inputs/random_uniform_4/sub" + op: "Sub" + input: "inputs/random_uniform_4/max" + input: "inputs/random_uniform_4/min" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "inputs/random_uniform_4/mul" + op: "Mul" + input: "inputs/random_uniform_4/RandomUniform" + input: "inputs/random_uniform_4/sub" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "inputs/random_uniform_4" + op: "Add" + input: "inputs/random_uniform_4/mul" + input: "inputs/random_uniform_4/min" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "inputs/x_seq_2" + op: "Variable" + device: "/device:GPU:*" + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 128 + } + dim { + size: 1024 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "inputs/x_seq_2/Assign" + op: "Assign" + input: "inputs/x_seq_2" + input: "inputs/random_uniform_4" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@inputs/x_seq_2" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "inputs/x_seq_2/read" + op: "Identity" + input: "inputs/x_seq_2" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@inputs/x_seq_2" + } + } + } +} +node { + name: "inputs/random_uniform_5/shape" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + 
value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\200\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "inputs/random_uniform_5/min" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "inputs/random_uniform_5/max" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "inputs/random_uniform_5/RandomUniform" + op: "RandomUniform" + input: "inputs/random_uniform_5/shape" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } +} +node { + name: "inputs/random_uniform_5/sub" + op: "Sub" + input: "inputs/random_uniform_5/max" + input: "inputs/random_uniform_5/min" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "inputs/random_uniform_5/mul" + op: "Mul" + input: "inputs/random_uniform_5/RandomUniform" + input: "inputs/random_uniform_5/sub" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "inputs/random_uniform_5" + op: "Add" + input: "inputs/random_uniform_5/mul" + input: "inputs/random_uniform_5/min" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "inputs/pad_seq_2" + op: "Variable" + device: "/device:GPU:*" + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 128 + } + dim { + size: 1 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "inputs/pad_seq_2/Assign" + op: "Assign" + input: "inputs/pad_seq_2" + input: "inputs/random_uniform_5" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@inputs/pad_seq_2" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "inputs/pad_seq_2/read" + op: "Identity" + input: "inputs/pad_seq_2" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@inputs/pad_seq_2" + } + } + } +} +node { + name: "inputs/random_uniform_6/shape" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\200\000\000\000\000\004\000\000" + } + } + } +} +node { + name: "inputs/random_uniform_6/min" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "inputs/random_uniform_6/max" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: 
DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "inputs/random_uniform_6/RandomUniform" + op: "RandomUniform" + input: "inputs/random_uniform_6/shape" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } +} +node { + name: "inputs/random_uniform_6/sub" + op: "Sub" + input: "inputs/random_uniform_6/max" + input: "inputs/random_uniform_6/min" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "inputs/random_uniform_6/mul" + op: "Mul" + input: "inputs/random_uniform_6/RandomUniform" + input: "inputs/random_uniform_6/sub" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "inputs/random_uniform_6" + op: "Add" + input: "inputs/random_uniform_6/mul" + input: "inputs/random_uniform_6/min" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "inputs/x_seq_3" + op: "Variable" + device: "/device:GPU:*" + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 128 + } + dim { + size: 1024 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "inputs/x_seq_3/Assign" + op: "Assign" + input: "inputs/x_seq_3" + input: "inputs/random_uniform_6" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@inputs/x_seq_3" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "inputs/x_seq_3/read" + op: "Identity" + input: "inputs/x_seq_3" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@inputs/x_seq_3" + } + } + } +} +node { + name: "inputs/random_uniform_7/shape" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\200\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "inputs/random_uniform_7/min" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "inputs/random_uniform_7/max" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "inputs/random_uniform_7/RandomUniform" + op: "RandomUniform" + input: "inputs/random_uniform_7/shape" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } +} +node { + name: "inputs/random_uniform_7/sub" + op: "Sub" + input: "inputs/random_uniform_7/max" + input: "inputs/random_uniform_7/min" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: 
"inputs/random_uniform_7/mul" + op: "Mul" + input: "inputs/random_uniform_7/RandomUniform" + input: "inputs/random_uniform_7/sub" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "inputs/random_uniform_7" + op: "Add" + input: "inputs/random_uniform_7/mul" + input: "inputs/random_uniform_7/min" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "inputs/pad_seq_3" + op: "Variable" + device: "/device:GPU:*" + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 128 + } + dim { + size: 1 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "inputs/pad_seq_3/Assign" + op: "Assign" + input: "inputs/pad_seq_3" + input: "inputs/random_uniform_7" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@inputs/pad_seq_3" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "inputs/pad_seq_3/read" + op: "Identity" + input: "inputs/pad_seq_3" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@inputs/pad_seq_3" + } + } + } +} +node { + name: "inputs/random_uniform_8/shape" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\200\000\000\000\000\004\000\000" + } + } + } +} +node { + name: "inputs/random_uniform_8/min" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "inputs/random_uniform_8/max" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "inputs/random_uniform_8/RandomUniform" + op: "RandomUniform" + input: "inputs/random_uniform_8/shape" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } +} +node { + name: "inputs/random_uniform_8/sub" + op: "Sub" + input: "inputs/random_uniform_8/max" + input: "inputs/random_uniform_8/min" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "inputs/random_uniform_8/mul" + op: "Mul" + input: "inputs/random_uniform_8/RandomUniform" + input: "inputs/random_uniform_8/sub" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "inputs/random_uniform_8" + op: "Add" + input: "inputs/random_uniform_8/mul" + input: "inputs/random_uniform_8/min" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "inputs/x_seq_4" + op: "Variable" + device: "/device:GPU:*" + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + 
size: 128 + } + dim { + size: 1024 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "inputs/x_seq_4/Assign" + op: "Assign" + input: "inputs/x_seq_4" + input: "inputs/random_uniform_8" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@inputs/x_seq_4" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "inputs/x_seq_4/read" + op: "Identity" + input: "inputs/x_seq_4" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@inputs/x_seq_4" + } + } + } +} +node { + name: "inputs/random_uniform_9/shape" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\200\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "inputs/random_uniform_9/min" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "inputs/random_uniform_9/max" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "inputs/random_uniform_9/RandomUniform" + op: "RandomUniform" + input: "inputs/random_uniform_9/shape" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } +} +node { + name: "inputs/random_uniform_9/sub" + op: "Sub" + input: "inputs/random_uniform_9/max" + input: "inputs/random_uniform_9/min" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "inputs/random_uniform_9/mul" + op: "Mul" + input: "inputs/random_uniform_9/RandomUniform" + input: "inputs/random_uniform_9/sub" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "inputs/random_uniform_9" + op: "Add" + input: "inputs/random_uniform_9/mul" + input: "inputs/random_uniform_9/min" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "inputs/pad_seq_4" + op: "Variable" + device: "/device:GPU:*" + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 128 + } + dim { + size: 1 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "inputs/pad_seq_4/Assign" + op: "Assign" + input: "inputs/pad_seq_4" + input: "inputs/random_uniform_9" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@inputs/pad_seq_4" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "inputs/pad_seq_4/read" + op: "Identity" + input: "inputs/pad_seq_4" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } 
+ attr { + key: "_class" + value { + list { + s: "loc:@inputs/pad_seq_4" + } + } + } +} +node { + name: "lstm_0/concat/axis" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } +} +node { + name: "lstm_0/concat" + op: "ConcatV2" + input: "inputs/x_seq_0/read" + input: "init_m" + input: "lstm_0/concat/axis" + device: "/device:GPU:*" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } +} +node { + name: "lstm_0/MatMul" + op: "MatMul" + input: "lstm_0/concat" + input: "weights/read" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: false + } + } +} +node { + name: "lstm_0/split/split_dim" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } +} +node { + name: "lstm_0/split" + op: "Split" + input: "lstm_0/split/split_dim" + input: "lstm_0/MatMul" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "num_split" + value { + i: 4 + } + } +} +node { + name: "lstm_0/Tanh" + op: "Tanh" + input: "lstm_0/split" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_0/Sigmoid" + op: "Sigmoid" + input: "lstm_0/split:1" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_0/Sigmoid_1" + op: "Sigmoid" + input: "lstm_0/split:2" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_0/Sigmoid_2" + op: "Sigmoid" + input: "lstm_0/split:3" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_0/mul" + op: "Mul" + input: "lstm_0/Sigmoid_1" + input: "init_c" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_0/Minimum/y" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "lstm_0/Minimum" + op: "Minimum" + input: "lstm_0/mul" + input: "lstm_0/Minimum/y" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_0/Maximum/y" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -1.0 + } + } + } +} +node { + name: "lstm_0/Maximum" + op: "Maximum" + input: "lstm_0/Minimum" + input: "lstm_0/Maximum/y" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_0/mul_1" + op: "Mul" + input: "lstm_0/Sigmoid" + input: "lstm_0/Tanh" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_0/Minimum_1/y" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + 
name: "lstm_0/Minimum_1" + op: "Minimum" + input: "lstm_0/mul_1" + input: "lstm_0/Minimum_1/y" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_0/Maximum_1/y" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -1.0 + } + } + } +} +node { + name: "lstm_0/Maximum_1" + op: "Maximum" + input: "lstm_0/Minimum_1" + input: "lstm_0/Maximum_1/y" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_0/add" + op: "Add" + input: "lstm_0/Maximum" + input: "lstm_0/Maximum_1" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_0/Minimum_2/y" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "lstm_0/Minimum_2" + op: "Minimum" + input: "lstm_0/add" + input: "lstm_0/Minimum_2/y" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_0/Maximum_2/y" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -1.0 + } + } + } +} +node { + name: "lstm_0/Maximum_2" + op: "Maximum" + input: "lstm_0/Minimum_2" + input: "lstm_0/Maximum_2/y" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_0/mul_2" + op: "Mul" + input: "lstm_0/Sigmoid_2" + input: "lstm_0/Maximum_2" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_0/Minimum_3/y" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "lstm_0/Minimum_3" + op: "Minimum" + input: "lstm_0/mul_2" + input: "lstm_0/Minimum_3/y" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_0/Maximum_3/y" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -1.0 + } + } + } +} +node { + name: "lstm_0/Maximum_3" + op: "Maximum" + input: "lstm_0/Minimum_3" + input: "lstm_0/Maximum_3/y" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_0/mul_3" + op: "Mul" + input: "init_c" + input: "inputs/pad_seq_0/read" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_0/sub/x" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "lstm_0/sub" + op: "Sub" + input: "lstm_0/sub/x" + input: "inputs/pad_seq_0/read" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_0/mul_4" + op: "Mul" + input: "lstm_0/Maximum_2" + input: "lstm_0/sub" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + 
name: "lstm_0/add_1" + op: "Add" + input: "lstm_0/mul_3" + input: "lstm_0/mul_4" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_0/mul_5" + op: "Mul" + input: "init_m" + input: "inputs/pad_seq_0/read" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_0/sub_1/x" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "lstm_0/sub_1" + op: "Sub" + input: "lstm_0/sub_1/x" + input: "inputs/pad_seq_0/read" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_0/mul_6" + op: "Mul" + input: "lstm_0/Maximum_3" + input: "lstm_0/sub_1" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_0/add_2" + op: "Add" + input: "lstm_0/mul_5" + input: "lstm_0/mul_6" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_0/out" + op: "Identity" + input: "lstm_0/add_2" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_1/concat/axis" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } +} +node { + name: "lstm_1/concat" + op: "ConcatV2" + input: "inputs/x_seq_1/read" + input: "lstm_0/add_2" + input: "lstm_1/concat/axis" + device: "/device:GPU:*" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } +} +node { + name: "lstm_1/MatMul" + op: "MatMul" + input: "lstm_1/concat" + input: "weights/read" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: false + } + } +} +node { + name: "lstm_1/split/split_dim" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } +} +node { + name: "lstm_1/split" + op: "Split" + input: "lstm_1/split/split_dim" + input: "lstm_1/MatMul" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "num_split" + value { + i: 4 + } + } +} +node { + name: "lstm_1/Tanh" + op: "Tanh" + input: "lstm_1/split" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_1/Sigmoid" + op: "Sigmoid" + input: "lstm_1/split:1" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_1/Sigmoid_1" + op: "Sigmoid" + input: "lstm_1/split:2" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_1/Sigmoid_2" + op: "Sigmoid" + input: "lstm_1/split:3" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_1/mul" + op: "Mul" + input: "lstm_1/Sigmoid_1" + input: "lstm_0/add_1" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_1/Minimum/y" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + 
value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "lstm_1/Minimum" + op: "Minimum" + input: "lstm_1/mul" + input: "lstm_1/Minimum/y" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_1/Maximum/y" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -1.0 + } + } + } +} +node { + name: "lstm_1/Maximum" + op: "Maximum" + input: "lstm_1/Minimum" + input: "lstm_1/Maximum/y" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_1/mul_1" + op: "Mul" + input: "lstm_1/Sigmoid" + input: "lstm_1/Tanh" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_1/Minimum_1/y" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "lstm_1/Minimum_1" + op: "Minimum" + input: "lstm_1/mul_1" + input: "lstm_1/Minimum_1/y" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_1/Maximum_1/y" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -1.0 + } + } + } +} +node { + name: "lstm_1/Maximum_1" + op: "Maximum" + input: "lstm_1/Minimum_1" + input: "lstm_1/Maximum_1/y" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_1/add" + op: "Add" + input: "lstm_1/Maximum" + input: "lstm_1/Maximum_1" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_1/Minimum_2/y" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "lstm_1/Minimum_2" + op: "Minimum" + input: "lstm_1/add" + input: "lstm_1/Minimum_2/y" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_1/Maximum_2/y" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -1.0 + } + } + } +} +node { + name: "lstm_1/Maximum_2" + op: "Maximum" + input: "lstm_1/Minimum_2" + input: "lstm_1/Maximum_2/y" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_1/mul_2" + op: "Mul" + input: "lstm_1/Sigmoid_2" + input: "lstm_1/Maximum_2" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_1/Minimum_3/y" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "lstm_1/Minimum_3" + op: "Minimum" + input: "lstm_1/mul_2" + input: "lstm_1/Minimum_3/y" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: 
"lstm_1/Maximum_3/y" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -1.0 + } + } + } +} +node { + name: "lstm_1/Maximum_3" + op: "Maximum" + input: "lstm_1/Minimum_3" + input: "lstm_1/Maximum_3/y" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_1/mul_3" + op: "Mul" + input: "lstm_0/add_1" + input: "inputs/pad_seq_1/read" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_1/sub/x" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "lstm_1/sub" + op: "Sub" + input: "lstm_1/sub/x" + input: "inputs/pad_seq_1/read" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_1/mul_4" + op: "Mul" + input: "lstm_1/Maximum_2" + input: "lstm_1/sub" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_1/add_1" + op: "Add" + input: "lstm_1/mul_3" + input: "lstm_1/mul_4" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_1/mul_5" + op: "Mul" + input: "lstm_0/add_2" + input: "inputs/pad_seq_1/read" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_1/sub_1/x" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "lstm_1/sub_1" + op: "Sub" + input: "lstm_1/sub_1/x" + input: "inputs/pad_seq_1/read" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_1/mul_6" + op: "Mul" + input: "lstm_1/Maximum_3" + input: "lstm_1/sub_1" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_1/add_2" + op: "Add" + input: "lstm_1/mul_5" + input: "lstm_1/mul_6" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_1/out" + op: "Identity" + input: "lstm_1/add_2" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_2/concat/axis" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } +} +node { + name: "lstm_2/concat" + op: "ConcatV2" + input: "inputs/x_seq_2/read" + input: "lstm_1/add_2" + input: "lstm_2/concat/axis" + device: "/device:GPU:*" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } +} +node { + name: "lstm_2/MatMul" + op: "MatMul" + input: "lstm_2/concat" + input: "weights/read" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: false + } + } +} +node { + name: "lstm_2/split/split_dim" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + 
value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } +} +node { + name: "lstm_2/split" + op: "Split" + input: "lstm_2/split/split_dim" + input: "lstm_2/MatMul" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "num_split" + value { + i: 4 + } + } +} +node { + name: "lstm_2/Tanh" + op: "Tanh" + input: "lstm_2/split" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_2/Sigmoid" + op: "Sigmoid" + input: "lstm_2/split:1" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_2/Sigmoid_1" + op: "Sigmoid" + input: "lstm_2/split:2" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_2/Sigmoid_2" + op: "Sigmoid" + input: "lstm_2/split:3" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_2/mul" + op: "Mul" + input: "lstm_2/Sigmoid_1" + input: "lstm_1/add_1" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_2/Minimum/y" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "lstm_2/Minimum" + op: "Minimum" + input: "lstm_2/mul" + input: "lstm_2/Minimum/y" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_2/Maximum/y" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -1.0 + } + } + } +} +node { + name: "lstm_2/Maximum" + op: "Maximum" + input: "lstm_2/Minimum" + input: "lstm_2/Maximum/y" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_2/mul_1" + op: "Mul" + input: "lstm_2/Sigmoid" + input: "lstm_2/Tanh" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_2/Minimum_1/y" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "lstm_2/Minimum_1" + op: "Minimum" + input: "lstm_2/mul_1" + input: "lstm_2/Minimum_1/y" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_2/Maximum_1/y" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -1.0 + } + } + } +} +node { + name: "lstm_2/Maximum_1" + op: "Maximum" + input: "lstm_2/Minimum_1" + input: "lstm_2/Maximum_1/y" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_2/add" + op: "Add" + input: "lstm_2/Maximum" + input: "lstm_2/Maximum_1" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_2/Minimum_2/y" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "lstm_2/Minimum_2" + op: "Minimum" + 
input: "lstm_2/add" + input: "lstm_2/Minimum_2/y" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_2/Maximum_2/y" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -1.0 + } + } + } +} +node { + name: "lstm_2/Maximum_2" + op: "Maximum" + input: "lstm_2/Minimum_2" + input: "lstm_2/Maximum_2/y" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_2/mul_2" + op: "Mul" + input: "lstm_2/Sigmoid_2" + input: "lstm_2/Maximum_2" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_2/Minimum_3/y" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "lstm_2/Minimum_3" + op: "Minimum" + input: "lstm_2/mul_2" + input: "lstm_2/Minimum_3/y" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_2/Maximum_3/y" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -1.0 + } + } + } +} +node { + name: "lstm_2/Maximum_3" + op: "Maximum" + input: "lstm_2/Minimum_3" + input: "lstm_2/Maximum_3/y" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_2/mul_3" + op: "Mul" + input: "lstm_1/add_1" + input: "inputs/pad_seq_2/read" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_2/sub/x" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "lstm_2/sub" + op: "Sub" + input: "lstm_2/sub/x" + input: "inputs/pad_seq_2/read" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_2/mul_4" + op: "Mul" + input: "lstm_2/Maximum_2" + input: "lstm_2/sub" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_2/add_1" + op: "Add" + input: "lstm_2/mul_3" + input: "lstm_2/mul_4" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_2/mul_5" + op: "Mul" + input: "lstm_1/add_2" + input: "inputs/pad_seq_2/read" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_2/sub_1/x" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "lstm_2/sub_1" + op: "Sub" + input: "lstm_2/sub_1/x" + input: "inputs/pad_seq_2/read" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_2/mul_6" + op: "Mul" + input: "lstm_2/Maximum_3" + input: "lstm_2/sub_1" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_2/add_2" + op: "Add" + input: "lstm_2/mul_5" + input: "lstm_2/mul_6" + device: "/device:GPU:*" + attr { + key: "T" + 
value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_2/out" + op: "Identity" + input: "lstm_2/add_2" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_3/concat/axis" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } +} +node { + name: "lstm_3/concat" + op: "ConcatV2" + input: "inputs/x_seq_3/read" + input: "lstm_2/add_2" + input: "lstm_3/concat/axis" + device: "/device:GPU:*" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } +} +node { + name: "lstm_3/MatMul" + op: "MatMul" + input: "lstm_3/concat" + input: "weights/read" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: false + } + } +} +node { + name: "lstm_3/split/split_dim" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } +} +node { + name: "lstm_3/split" + op: "Split" + input: "lstm_3/split/split_dim" + input: "lstm_3/MatMul" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "num_split" + value { + i: 4 + } + } +} +node { + name: "lstm_3/Tanh" + op: "Tanh" + input: "lstm_3/split" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_3/Sigmoid" + op: "Sigmoid" + input: "lstm_3/split:1" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_3/Sigmoid_1" + op: "Sigmoid" + input: "lstm_3/split:2" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_3/Sigmoid_2" + op: "Sigmoid" + input: "lstm_3/split:3" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_3/mul" + op: "Mul" + input: "lstm_3/Sigmoid_1" + input: "lstm_2/add_1" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_3/Minimum/y" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "lstm_3/Minimum" + op: "Minimum" + input: "lstm_3/mul" + input: "lstm_3/Minimum/y" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_3/Maximum/y" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -1.0 + } + } + } +} +node { + name: "lstm_3/Maximum" + op: "Maximum" + input: "lstm_3/Minimum" + input: "lstm_3/Maximum/y" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_3/mul_1" + op: "Mul" + input: "lstm_3/Sigmoid" + input: "lstm_3/Tanh" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_3/Minimum_1/y" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + 
key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "lstm_3/Minimum_1" + op: "Minimum" + input: "lstm_3/mul_1" + input: "lstm_3/Minimum_1/y" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_3/Maximum_1/y" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -1.0 + } + } + } +} +node { + name: "lstm_3/Maximum_1" + op: "Maximum" + input: "lstm_3/Minimum_1" + input: "lstm_3/Maximum_1/y" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_3/add" + op: "Add" + input: "lstm_3/Maximum" + input: "lstm_3/Maximum_1" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_3/Minimum_2/y" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "lstm_3/Minimum_2" + op: "Minimum" + input: "lstm_3/add" + input: "lstm_3/Minimum_2/y" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_3/Maximum_2/y" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -1.0 + } + } + } +} +node { + name: "lstm_3/Maximum_2" + op: "Maximum" + input: "lstm_3/Minimum_2" + input: "lstm_3/Maximum_2/y" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_3/mul_2" + op: "Mul" + input: "lstm_3/Sigmoid_2" + input: "lstm_3/Maximum_2" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_3/Minimum_3/y" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "lstm_3/Minimum_3" + op: "Minimum" + input: "lstm_3/mul_2" + input: "lstm_3/Minimum_3/y" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_3/Maximum_3/y" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -1.0 + } + } + } +} +node { + name: "lstm_3/Maximum_3" + op: "Maximum" + input: "lstm_3/Minimum_3" + input: "lstm_3/Maximum_3/y" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_3/mul_3" + op: "Mul" + input: "lstm_2/add_1" + input: "inputs/pad_seq_3/read" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_3/sub/x" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "lstm_3/sub" + op: "Sub" + input: "lstm_3/sub/x" + input: "inputs/pad_seq_3/read" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_3/mul_4" + op: "Mul" + input: 
"lstm_3/Maximum_2" + input: "lstm_3/sub" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_3/add_1" + op: "Add" + input: "lstm_3/mul_3" + input: "lstm_3/mul_4" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_3/mul_5" + op: "Mul" + input: "lstm_2/add_2" + input: "inputs/pad_seq_3/read" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_3/sub_1/x" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "lstm_3/sub_1" + op: "Sub" + input: "lstm_3/sub_1/x" + input: "inputs/pad_seq_3/read" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_3/mul_6" + op: "Mul" + input: "lstm_3/Maximum_3" + input: "lstm_3/sub_1" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_3/add_2" + op: "Add" + input: "lstm_3/mul_5" + input: "lstm_3/mul_6" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_3/out" + op: "Identity" + input: "lstm_3/add_2" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_4/concat/axis" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } +} +node { + name: "lstm_4/concat" + op: "ConcatV2" + input: "inputs/x_seq_4/read" + input: "lstm_3/add_2" + input: "lstm_4/concat/axis" + device: "/device:GPU:*" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } +} +node { + name: "lstm_4/MatMul" + op: "MatMul" + input: "lstm_4/concat" + input: "weights/read" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: false + } + } +} +node { + name: "lstm_4/split/split_dim" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } +} +node { + name: "lstm_4/split" + op: "Split" + input: "lstm_4/split/split_dim" + input: "lstm_4/MatMul" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "num_split" + value { + i: 4 + } + } +} +node { + name: "lstm_4/Tanh" + op: "Tanh" + input: "lstm_4/split" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_4/Sigmoid" + op: "Sigmoid" + input: "lstm_4/split:1" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_4/Sigmoid_1" + op: "Sigmoid" + input: "lstm_4/split:2" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_4/Sigmoid_2" + op: "Sigmoid" + input: "lstm_4/split:3" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_4/mul" + op: "Mul" + input: "lstm_4/Sigmoid_1" + input: "lstm_3/add_1" + device: "/device:GPU:*" + attr { + key: 
"T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_4/Minimum/y" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "lstm_4/Minimum" + op: "Minimum" + input: "lstm_4/mul" + input: "lstm_4/Minimum/y" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_4/Maximum/y" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -1.0 + } + } + } +} +node { + name: "lstm_4/Maximum" + op: "Maximum" + input: "lstm_4/Minimum" + input: "lstm_4/Maximum/y" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_4/mul_1" + op: "Mul" + input: "lstm_4/Sigmoid" + input: "lstm_4/Tanh" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_4/Minimum_1/y" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "lstm_4/Minimum_1" + op: "Minimum" + input: "lstm_4/mul_1" + input: "lstm_4/Minimum_1/y" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_4/Maximum_1/y" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -1.0 + } + } + } +} +node { + name: "lstm_4/Maximum_1" + op: "Maximum" + input: "lstm_4/Minimum_1" + input: "lstm_4/Maximum_1/y" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_4/add" + op: "Add" + input: "lstm_4/Maximum" + input: "lstm_4/Maximum_1" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_4/Minimum_2/y" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "lstm_4/Minimum_2" + op: "Minimum" + input: "lstm_4/add" + input: "lstm_4/Minimum_2/y" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_4/Maximum_2/y" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -1.0 + } + } + } +} +node { + name: "lstm_4/Maximum_2" + op: "Maximum" + input: "lstm_4/Minimum_2" + input: "lstm_4/Maximum_2/y" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_4/mul_2" + op: "Mul" + input: "lstm_4/Sigmoid_2" + input: "lstm_4/Maximum_2" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_4/Minimum_3/y" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "lstm_4/Minimum_3" + op: "Minimum" + input: "lstm_4/mul_2" + 
input: "lstm_4/Minimum_3/y" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_4/Maximum_3/y" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -1.0 + } + } + } +} +node { + name: "lstm_4/Maximum_3" + op: "Maximum" + input: "lstm_4/Minimum_3" + input: "lstm_4/Maximum_3/y" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_4/mul_3" + op: "Mul" + input: "lstm_3/add_1" + input: "inputs/pad_seq_4/read" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_4/sub/x" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "lstm_4/sub" + op: "Sub" + input: "lstm_4/sub/x" + input: "inputs/pad_seq_4/read" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_4/mul_4" + op: "Mul" + input: "lstm_4/Maximum_2" + input: "lstm_4/sub" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_4/add_1" + op: "Add" + input: "lstm_4/mul_3" + input: "lstm_4/mul_4" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_4/mul_5" + op: "Mul" + input: "lstm_3/add_2" + input: "inputs/pad_seq_4/read" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_4/sub_1/x" + op: "Const" + device: "/device:GPU:*" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } +} +node { + name: "lstm_4/sub_1" + op: "Sub" + input: "lstm_4/sub_1/x" + input: "inputs/pad_seq_4/read" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_4/mul_6" + op: "Mul" + input: "lstm_4/Maximum_3" + input: "lstm_4/sub_1" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_4/add_2" + op: "Add" + input: "lstm_4/mul_5" + input: "lstm_4/mul_6" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "lstm_4/out" + op: "Identity" + input: "lstm_4/add_2" + device: "/device:GPU:*" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +versions { + producer: 19 +} diff --git a/tensorflow/compiler/tests/lstm_test.py b/tensorflow/compiler/tests/lstm_test.py new file mode 100644 index 0000000000..9ffeb6c2a2 --- /dev/null +++ b/tensorflow/compiler/tests/lstm_test.py @@ -0,0 +1,293 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the LSTM cell and layer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import numpy as np + +from tensorflow.compiler.tests import lstm +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import flags as flags_lib +from tensorflow.python.platform import test + +flags = flags_lib +FLAGS = flags.FLAGS + +flags.DEFINE_integer('batch_size', 128, + 'Inputs are fed in batches of this size, for both ' + 'inference and training. Larger values cause the matmul ' + 'in each LSTM cell to have higher dimensionality.') +flags.DEFINE_integer('seq_length', 60, + 'Length of the unrolled sequence of LSTM cells in a layer.' + 'Larger values cause more LSTM matmuls to be run.') +flags.DEFINE_integer('num_inputs', 1024, + 'Dimension of inputs that are fed into each LSTM cell.') +flags.DEFINE_integer('num_nodes', 1024, 'Number of nodes in each LSTM cell.') +flags.DEFINE_string('device', 'gpu', + 'TensorFlow device to assign ops to, e.g. "gpu", "cpu". ' + 'For details see documentation for tf.Graph.device.') + +flags.DEFINE_string('dump_graph_dir', '', 'If non-empty, dump graphs in ' + '*.pbtxt format to this directory.') + + +def _DumpGraph(graph, basename): + if FLAGS.dump_graph_dir: + name = os.path.join(FLAGS.dump_graph_dir, basename + '.pbtxt') + with open(name, 'w') as f: + f.write(str(graph.as_graph_def())) + + +def _Sigmoid(x): + return 1. / (1. + np.exp(-x)) + + +def _Clip(x): + return np.maximum(np.minimum(x, 1.), -1.) + + +class LSTMTest(test.TestCase): + + def setUp(self): + # The tests for a single LSTM cell and LSTM layer use these values as + # inputs. We always set the dimensionality of num_inputs=1; thus batch_size + # actually represents the different input cases. + self._inputs = np.array([[-1.], [-.5], [0.], [.5], [1.]], np.float32) + self._batch_size = len(self._inputs) + + def _NextC(self, inputs, weight, m_prev, c_prev): + """Returns the next c states of an LSTM cell.""" + x = (inputs + m_prev) * weight + return _Clip(_Clip(_Sigmoid(x) * c_prev) + _Clip(_Sigmoid(x) * np.tanh(x))) + + def _NextM(self, inputs, weight, m_prev, c_prev): + """Returns the next m states of an LSTM cell.""" + x = (inputs + m_prev) * weight + return _Clip(_Sigmoid(x) * self._NextC(inputs, weight, m_prev, c_prev)) + + def _RunLSTMCell(self, basename, init_weights, m_prev_scalar, c_prev_scalar, + pad_scalar): + with self.test_session() as sess: + num_inputs = 1 + num_nodes = 1 + + weights = init_weights(lstm.LSTMCellWeightsShape(num_inputs, num_nodes)) + m_prev = constant_op.constant([[m_prev_scalar]] * self._batch_size) + c_prev = constant_op.constant([[c_prev_scalar]] * self._batch_size) + x = constant_op.constant(self._inputs) + pad = constant_op.constant([[pad_scalar]] * self._batch_size) + + m, c = lstm.LSTMCell(weights, m_prev, c_prev, x, pad) + _DumpGraph(sess.graph, 'lstm_cell_%s_%d_%d_%d' % + (basename, m_prev_scalar, c_prev_scalar, pad_scalar)) + + # Initialize variables and run the unrolled LSTM step. 
+ sess.run(variables.global_variables_initializer()) + return sess.run([m, c]) + + def testLSTMCell(self): + # Run with all-0 weights, no padding. + m, c = self._RunLSTMCell('zeros', init_ops.zeros_initializer(), 0., 0., 0.) + self.assertAllClose(m, [[0.]] * self._batch_size) + self.assertAllClose(c, [[0.]] * self._batch_size) + m, c = self._RunLSTMCell('zeros', init_ops.zeros_initializer(), 0., 1., 0.) + self.assertAllClose(m, [[.25]] * self._batch_size) + self.assertAllClose(c, [[.5]] * self._batch_size) + m, c = self._RunLSTMCell('zeros', init_ops.zeros_initializer(), 1., 0., 0.) + self.assertAllClose(m, [[.0]] * self._batch_size) + self.assertAllClose(c, [[.0]] * self._batch_size) + m, c = self._RunLSTMCell('zeros', init_ops.zeros_initializer(), 1., 1., 0.) + self.assertAllClose(m, [[.25]] * self._batch_size) + self.assertAllClose(c, [[.5]] * self._batch_size) + + # Run with all-1 weights, no padding. + for m_prev in [0., 1.]: + for c_prev in [0., 1.]: + m, c = self._RunLSTMCell('ones', + init_ops.ones_initializer(), m_prev, c_prev, + 0.) + self.assertAllClose(m, self._NextM(self._inputs, 1., m_prev, c_prev)) + self.assertAllClose(c, self._NextC(self._inputs, 1., m_prev, c_prev)) + + # Run with random weights. + for weight in np.random.rand(3): + weight_tf = constant_op.constant(weight, dtypes.float32) + random_weight = lambda shape, w=weight_tf: array_ops.fill(shape, w) + + # No padding. + for m_prev in [0., 1.]: + for c_prev in [0., 1.]: + m, c = self._RunLSTMCell('random', random_weight, m_prev, c_prev, 0.) + self.assertAllClose(m, + self._NextM(self._inputs, weight, m_prev, c_prev)) + self.assertAllClose(c, + self._NextC(self._inputs, weight, m_prev, c_prev)) + + # Set padding. + for m_prev in [0., 1.]: + for c_prev in [0., 1.]: + m, c = self._RunLSTMCell('random', random_weight, m_prev, c_prev, 1.) + self.assertAllClose(m, [[m_prev]] * self._batch_size) + self.assertAllClose(c, [[c_prev]] * self._batch_size) + + def testLSTMLayerErrors(self): + num_inputs = 1 + num_nodes = 1 + seq_length = 3 + + weights = array_ops.zeros(lstm.LSTMCellWeightsShape(num_inputs, num_nodes)) + m = constant_op.constant([[0.]] * self._batch_size) + c = constant_op.constant([[0.]] * self._batch_size) + x_seq = [constant_op.constant(self._inputs)] * seq_length + pad = constant_op.constant([[0.]] * self._batch_size) + + with self.assertRaisesWithPredicateMatch(ValueError, 'length of x_seq'): + lstm.LSTMLayer('lstm', weights, m, c, x_seq, [pad]) + with self.assertRaisesWithPredicateMatch(ValueError, 'length of x_seq'): + lstm.LSTMLayer('lstm', weights, m, c, x_seq, [pad] * 2) + with self.assertRaisesWithPredicateMatch(ValueError, 'length of x_seq'): + lstm.LSTMLayer('lstm', weights, m, c, x_seq, [pad] * 4) + + def _RunLSTMLayer(self, basename, init_weights, m_init_scalar, c_init_scalar, + pad_scalar): + with self.test_session() as sess: + num_inputs = 1 + num_nodes = 1 + seq_length = 3 + + weights = init_weights(lstm.LSTMCellWeightsShape(num_inputs, num_nodes)) + m_init = constant_op.constant([[m_init_scalar]] * self._batch_size) + c_init = constant_op.constant([[c_init_scalar]] * self._batch_size) + x_seq = [constant_op.constant(self._inputs)] * seq_length + pad_seq = [constant_op.constant([[pad_scalar]] * self._batch_size) + ] * seq_length + + out_seq = lstm.LSTMLayer('lstm', weights, m_init, c_init, x_seq, pad_seq) + _DumpGraph(sess.graph, 'lstm_layer_%s_%d_%d_%d' % + (basename, m_init_scalar, c_init_scalar, pad_scalar)) + + # Initialize variables and run the unrolled LSTM layer. 
+ sess.run(variables.global_variables_initializer()) + return sess.run(out_seq) + + def testLSTMLayer(self): + # Run with all-0 weights, no padding. + o = self._RunLSTMLayer('zeros', init_ops.zeros_initializer(), 0., 0., 0.) + self.assertAllClose(o, [[[0.]] * self._batch_size] * 3) + o = self._RunLSTMLayer('zeros', init_ops.zeros_initializer(), 0., 1., 0.) + self.assertAllClose(o, [[[.25]] * self._batch_size, + [[.125]] * self._batch_size, + [[.0625]] * self._batch_size]) + o = self._RunLSTMLayer('zeros', init_ops.zeros_initializer(), 1., 0., 0.) + self.assertAllClose(o, [[[0.]] * self._batch_size] * 3) + o = self._RunLSTMLayer('zeros', init_ops.zeros_initializer(), 1., 1., 0.) + self.assertAllClose(o, [[[.25]] * self._batch_size, + [[.125]] * self._batch_size, + [[.0625]] * self._batch_size]) + + # Run with all-1 weights, no padding. + weight1 = 1. + for m_init in [0., 1.]: + for c_init in [0., 1.]: + o = self._RunLSTMLayer('ones', + init_ops.ones_initializer(), m_init, c_init, 0.) + m0 = self._NextM(self._inputs, weight1, m_init, c_init) + c0 = self._NextC(self._inputs, weight1, m_init, c_init) + self.assertAllClose(o[0], m0) + m1 = self._NextM(self._inputs, weight1, m0, c0) + c1 = self._NextC(self._inputs, weight1, m0, c0) + self.assertAllClose(o[1], m1) + m2 = self._NextM(self._inputs, weight1, m1, c1) + self.assertAllClose(o[2], m2) + + # Run with random weights. + for weight in np.random.rand(3): + weight_tf = constant_op.constant(weight, dtypes.float32) + random_weight = lambda shape, w=weight_tf: array_ops.fill(shape, w) + + # No padding. + for m_init in [0., 1.]: + for c_init in [0., 1.]: + o = self._RunLSTMLayer('random', random_weight, m_init, c_init, 0.) + m0 = self._NextM(self._inputs, weight, m_init, c_init) + c0 = self._NextC(self._inputs, weight, m_init, c_init) + self.assertAllClose(o[0], m0) + m1 = self._NextM(self._inputs, weight, m0, c0) + c1 = self._NextC(self._inputs, weight, m0, c0) + self.assertAllClose(o[1], m1) + m2 = self._NextM(self._inputs, weight, m1, c1) + self.assertAllClose(o[2], m2) + + # Set padding. + o = self._RunLSTMLayer('random', random_weight, 0., 0., 1.) + self.assertAllClose(o, [[[0.]] * self._batch_size] * 3) + o = self._RunLSTMLayer('random', random_weight, 0., 1., 1.) + self.assertAllClose(o, [[[0.]] * self._batch_size] * 3) + o = self._RunLSTMLayer('random', random_weight, 1., 0., 1.) + self.assertAllClose(o, [[[1.]] * self._batch_size] * 3) + o = self._RunLSTMLayer('random', random_weight, 1., 1., 1.) + self.assertAllClose(o, [[[1.]] * self._batch_size] * 3) + + +class LSTMBenchmark(test.Benchmark): + """Mcro-benchmarks for a single layer of LSTM cells.""" + + def _LayerBuilder(self, do_training): + out_seq, weights = lstm.BuildLSTMLayer(FLAGS.batch_size, FLAGS.seq_length, + FLAGS.num_inputs, FLAGS.num_nodes) + name, fetches = ('lstm_layer_inference', out_seq) + if do_training: + # Not a real loss function, but good enough for benchmarking backprop. 
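+      # Summing every element of every timestep's output yields a scalar whose
+      # gradient reaches all of the layer's weights, which is all the backprop
+      # benchmark needs to exercise.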
+ loss = math_ops.reduce_sum(math_ops.add_n(out_seq)) + dw = gradients_impl.gradients(loss, weights) + name, fetches = ('lstm_layer_training', dw) + + _DumpGraph(ops.get_default_graph(), + '%s_%d_%d_%d_%d' % (name, FLAGS.batch_size, FLAGS.seq_length, + FLAGS.num_inputs, FLAGS.num_nodes)) + return name, fetches + + def benchmarkLayerInference(self): + xla_test.Benchmark(self, lambda: self._LayerBuilder(False), False, + FLAGS.device) + + def benchmarkLayerInferenceXLA(self): + xla_test.Benchmark(self, lambda: self._LayerBuilder(False), True, + FLAGS.device) + + def benchmarkLayerTraining(self): + xla_test.Benchmark(self, lambda: self._LayerBuilder(True), False, + FLAGS.device) + + def benchmarkLayerTrainingXLA(self): + xla_test.Benchmark(self, lambda: self._LayerBuilder(True), True, + FLAGS.device) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/compiler/tests/nary_ops_test.py b/tensorflow/compiler/tests/nary_ops_test.py new file mode 100644 index 0000000000..566c02e72f --- /dev/null +++ b/tensorflow/compiler/tests/nary_ops_test.py @@ -0,0 +1,209 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test cases for operators with > 3 or arbitrary numbers of arguments.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import googletest + + +class NAryOpsTest(XLATestCase): + + def _testNAry(self, op, args, expected): + with self.test_session() as session: + with self.test_scope(): + placeholders = [ + array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape) + for arg in args + ] + feeds = {placeholders[i]: args[i] for i in range(0, len(args))} + output = op(placeholders) + result = session.run(output, feeds) + self.assertAllClose(result, expected, rtol=1e-3) + + def testFloat(self): + self._testNAry(math_ops.add_n, + [np.array([[1, 2, 3]], dtype=np.float32)], + expected=np.array([[1, 2, 3]], dtype=np.float32)) + + self._testNAry(math_ops.add_n, + [np.array([1, 2], dtype=np.float32), + np.array([10, 20], dtype=np.float32)], + expected=np.array([11, 22], dtype=np.float32)) + self._testNAry(math_ops.add_n, + [np.array([-4], dtype=np.float32), + np.array([10], dtype=np.float32), + np.array([42], dtype=np.float32)], + expected=np.array([48], dtype=np.float32)) + + def testConcat(self): + self._testNAry( + lambda x: array_ops.concat_v2(x, 0), [ + np.array( + [[1, 2, 3], [4, 5, 6]], dtype=np.float32), np.array( + [[7, 8, 9], [10, 11, 12]], dtype=np.float32) + ], + expected=np.array( + [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=np.float32)) + + self._testNAry( + lambda x: array_ops.concat_v2(x, 1), [ + 
np.array( + [[1, 2, 3], [4, 5, 6]], dtype=np.float32), np.array( + [[7, 8, 9], [10, 11, 12]], dtype=np.float32) + ], + expected=np.array( + [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]], dtype=np.float32)) + + def testSplitV(self): + with self.test_session() as session: + with self.test_scope(): + output = session.run( + array_ops.split(np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 0, 1, 2]], + dtype=np.float32), + [2, 2], 1)) + expected = [np.array([[1, 2], [5, 6], [9, 0]], dtype=np.float32), + np.array([[3, 4], [7, 8], [1, 2]], dtype=np.float32)] + self.assertAllEqual(output, expected) + + def testStridedSlice(self): + self._testNAry(lambda x: array_ops.strided_slice(*x), + [np.array([[], [], []], dtype=np.float32), + np.array([1, 0], dtype=np.int32), + np.array([3, 0], dtype=np.int32), + np.array([1, 1], dtype=np.int32)], + expected=np.array([[], []], dtype=np.float32)) + + self._testNAry(lambda x: array_ops.strided_slice(*x), + [np.array([[], [], []], dtype=np.float32), + np.array([1, 0], dtype=np.int64), + np.array([3, 0], dtype=np.int64), + np.array([1, 1], dtype=np.int64)], + expected=np.array([[], []], dtype=np.float32)) + + self._testNAry(lambda x: array_ops.strided_slice(*x), + [np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], + dtype=np.float32), + np.array([1, 1], dtype=np.int32), + np.array([3, 3], dtype=np.int32), + np.array([1, 1], dtype=np.int32)], + expected=np.array([[5, 6], [8, 9]], dtype=np.float32)) + + self._testNAry(lambda x: array_ops.strided_slice(*x), + [np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], + dtype=np.float32), + np.array([0, 2], dtype=np.int32), + np.array([2, 0], dtype=np.int32), + np.array([1, -1], dtype=np.int32)], + expected=np.array([[3, 2], [6, 5]], dtype=np.float32)) + + self._testNAry(lambda x: x[0][0:2, array_ops.newaxis, ::-1], + [np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], + dtype=np.float32)], + expected=np.array([[[3, 2, 1]], [[6, 5, 4]]], + dtype=np.float32)) + + self._testNAry(lambda x: x[0][1, :, array_ops.newaxis], + [np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], + dtype=np.float32)], + expected=np.array([[4], [5], [6]], dtype=np.float32)) + + def testStridedSliceGrad(self): + # Tests cases where input shape is empty. + self._testNAry(lambda x: array_ops.strided_slice_grad(*x), + [np.array([], dtype=np.int32), + np.array([], dtype=np.int32), + np.array([], dtype=np.int32), + np.array([], dtype=np.int32), + np.float32(0.5)], + expected=np.array(np.float32(0.5), dtype=np.float32)) + + # Tests case where input shape is non-empty, but gradients are empty. 
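+    # Slicing [0, 0) of a length-3 vector yields an empty slice, so scattering
+    # the empty gradient back produces an all-zero input gradient.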
+ self._testNAry(lambda x: array_ops.strided_slice_grad(*x), + [np.array([3], dtype=np.int32), + np.array([0], dtype=np.int32), + np.array([0], dtype=np.int32), + np.array([1], dtype=np.int32), + np.array([], dtype=np.float32)], + expected=np.array([0, 0, 0], dtype=np.float32)) + + self._testNAry(lambda x: array_ops.strided_slice_grad(*x), + [np.array([3, 0], dtype=np.int32), + np.array([1, 0], dtype=np.int32), + np.array([3, 0], dtype=np.int32), + np.array([1, 1], dtype=np.int32), + np.array([[], []], dtype=np.float32)], + expected=np.array([[], [], []], dtype=np.float32)) + + self._testNAry(lambda x: array_ops.strided_slice_grad(*x), + [np.array([3, 3], dtype=np.int32), + np.array([1, 1], dtype=np.int32), + np.array([3, 3], dtype=np.int32), + np.array([1, 1], dtype=np.int32), + np.array([[5, 6], [8, 9]], dtype=np.float32)], + expected=np.array([[0, 0, 0], [0, 5, 6], [0, 8, 9]], + dtype=np.float32)) + + def ssg_test(x): + return array_ops.strided_slice_grad(*x, shrink_axis_mask=0x4, + new_axis_mask=0x1) + + self._testNAry(ssg_test, + [np.array([3, 1, 3], dtype=np.int32), + np.array([0, 0, 0, 2], dtype=np.int32), + np.array([0, 3, 1, -4], dtype=np.int32), + np.array([1, 2, 1, -3], dtype=np.int32), + np.array([[[1], [2]]], dtype=np.float32)], + expected=np.array([[[0, 0, 1]], [[0, 0, 0]], [[0, 0, 2]]], + dtype=np.float32)) + + ssg_test2 = lambda x: array_ops.strided_slice_grad(*x, new_axis_mask=0x15) + self._testNAry(ssg_test2, + [np.array([4, 4], dtype=np.int32), + np.array([0, 0, 0, 1, 0], dtype=np.int32), + np.array([0, 3, 0, 4, 0], dtype=np.int32), + np.array([1, 2, 1, 2, 1], dtype=np.int32), + np.array([[[[[1], [2]]], [[[3], [4]]]]], dtype=np.float32)], + expected=np.array([[0, 1, 0, 2], [0, 0, 0, 0], [0, 3, 0, 4], + [0, 0, 0, 0]], dtype=np.float32)) + + self._testNAry(lambda x: array_ops.strided_slice_grad(*x), + [np.array([3, 3], dtype=np.int32), + np.array([0, 2], dtype=np.int32), + np.array([2, 0], dtype=np.int32), + np.array([1, -1], dtype=np.int32), + np.array([[1, 2], [3, 4]], dtype=np.float32)], + expected=np.array([[0, 2, 1], [0, 4, 3], [0, 0, 0]], + dtype=np.float32)) + + self._testNAry(lambda x: array_ops.strided_slice_grad(*x), + [np.array([3, 3], dtype=np.int32), + np.array([2, 2], dtype=np.int32), + np.array([0, 1], dtype=np.int32), + np.array([-1, -2], dtype=np.int32), + np.array([[1], [2]], dtype=np.float32)], + expected=np.array([[0, 0, 0], [0, 0, 2], [0, 0, 1]], + dtype=np.float32)) + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/compiler/tests/nullary_ops_test.py b/tensorflow/compiler/tests/nullary_ops_test.py new file mode 100644 index 0000000000..6f588d8ab5 --- /dev/null +++ b/tensorflow/compiler/tests/nullary_ops_test.py @@ -0,0 +1,61 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Test cases for operators with no arguments.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.platform import googletest + + +class NullaryOpsTest(XLATestCase): + + def _testNullary(self, op, expected): + with self.test_session() as session: + with self.test_scope(): + output = op() + result = session.run(output) + self.assertAllClose(result, expected, rtol=1e-3) + + def testNoOp(self): + with self.test_session(): + with self.test_scope(): + output = control_flow_ops.no_op() + # This should not crash. + output.run() + + def testConstants(self): + constants = [ + np.float32(42), + np.array([], dtype=np.float32), + np.array([1, 2], dtype=np.float32), + np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32), + np.array([[[1, 2], [3, 4], [5, 6]], [[10, 20], [30, 40], [50, 60]]], + dtype=np.float32), + np.array([[[]], [[]]], dtype=np.float32), + np.array([[[[1]]]], dtype=np.float32), + ] + for c in constants: + self._testNullary(lambda c=c: constant_op.constant(c), expected=c) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/compiler/tests/pooling_ops_test.py b/tensorflow/compiler/tests/pooling_ops_test.py new file mode 100644 index 0000000000..52290e6354 --- /dev/null +++ b/tensorflow/compiler/tests/pooling_ops_test.py @@ -0,0 +1,511 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Functional tests for pooling operations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_nn_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.platform import googletest + + +def NHWCToNCHW(input_tensor): + """Convert the input from NHWC format to NCHW. + + Args: + input_tensor: a 4-D tensor, or a 4-element array representing the same. + + Returns: + the converted tensor or a shape array + """ + if isinstance(input_tensor, ops.Tensor): + return array_ops.transpose(input_tensor, [0, 3, 1, 2]) + else: + return [input_tensor[0], input_tensor[3], input_tensor[1], input_tensor[2]] + + +def NCHWToNHWC(input_tensor): + """Convert the input from NCHW format to NHWC. + + Args: + input_tensor: a 4-D tensor, or a 4-element array representing the same. 
+ + Returns: + the converted tensor or a shape array + """ + if isinstance(input_tensor, ops.Tensor): + return array_ops.transpose(input_tensor, [0, 2, 3, 1]) + else: + return [input_tensor[0], input_tensor[2], input_tensor[3], input_tensor[1]] + + +def GetTestConfigs(): + """Get all the valid tests configs to run. + + Returns: + all the valid test configs + """ + test_configs = ["NHWC", "NCHW"] + return test_configs + + +class PoolingTest(XLATestCase): + + def _VerifyOneTest(self, pool_func, input_sizes, ksize, strides, padding, + data_format, expected): + """Verifies the output values of the pooling function. + + Args: + pool_func: Function to be called, currently only co.MaxPool. + input_sizes: Input tensor dimensions. + ksize: The kernel size dimensions + strides: The stride dimensions + padding: Padding type. + data_format: The data format we use to run the pooling operation. + expected: An array containing the expected operation outputs. + """ + total_size = np.prod(input_sizes) + # Initializes the input tensor with array containing incrementing + # numbers from 1. + x = np.array([f * 1.0 for f in range(1, total_size + 1)], dtype=np.float32) + x = x.reshape(input_sizes) + with self.test_session() as sess: + with self.test_scope(): + inputs = array_ops.placeholder(dtypes.float32) + t = inputs + if data_format == "NCHW": + t = NHWCToNCHW(t) + ksize = NHWCToNCHW(ksize) + strides = NHWCToNCHW(strides) + t = pool_func(t, + ksize=ksize, + strides=strides, + padding=padding, + data_format=data_format) + if data_format == "NCHW": + t = NCHWToNHWC(t) + actual = sess.run(t, {inputs: x}) + self.assertAllClose(expected, actual.flatten(), rtol=1e-5, atol=1e-6) + + def _VerifyValues(self, pool_func, input_sizes, ksize, strides, padding, + expected): + """Verifies the output values of the pooling function. + + Args: + pool_func: Function to be called, co.MaxPool, co.AvgPool, + or the Lua version. + input_sizes: Input tensor dimensions. + ksize: The kernel size dimensions + strides: The stride dimensions + padding: Padding type. + expected: An array containing the expected operation outputs. 
+ """ + for data_format in GetTestConfigs(): + self._VerifyOneTest(pool_func, input_sizes, ksize, strides, padding, + data_format, expected) + + def testMaxPoolValidPadding(self): + expected_output = [13.0, 14.0, 15.0] + self._VerifyValues(nn_ops.max_pool, + input_sizes=[1, 3, 3, 3], + ksize=[1, 2, 2, 1], + strides=[1, 2, 2, 1], + padding="VALID", + expected=expected_output) + + def testMaxPoolSamePadding(self): + expected_output = [13.0, 14.0, 15.0, 16.0, 17.0, 18.0] + self._VerifyValues(nn_ops.max_pool, + input_sizes=[1, 2, 3, 3], + ksize=[1, 2, 2, 1], + strides=[1, 2, 2, 1], + padding="SAME", + expected=expected_output) + + def testMaxPoolSamePaddingNonSquareWindow(self): + # input is: + # [1.0, 2.0 + # 3.0 4.0] + # + # Window of [x, x] should do: + # + # [max(1.0, 2.0), max(2.0, padded0), + # max(3.0, 4.0), max(4.0, padded0)] + self._VerifyValues( + nn_ops.max_pool, + input_sizes=[1, 2, 2, 1], + ksize=[1, 1, 2, 1], + strides=[1, 1, 1, 1], + padding="SAME", + expected=[2.0, 2.0, 4.0, 4.0]) + + def testMaxPoolValidPaddingUnevenStride(self): + self._VerifyValues( + nn_ops.max_pool, + input_sizes=[1, 4, 4, 1], + ksize=[1, 2, 2, 1], + strides=[1, 1, 2, 1], + padding="VALID", + expected=[6.0, 8.0, 10.0, 12.0, 14.0, 16.0]) + self._VerifyValues( + nn_ops.max_pool, + input_sizes=[1, 4, 4, 1], + ksize=[1, 2, 2, 1], + strides=[1, 2, 1, 1], + padding="VALID", + expected=[6.0, 7.0, 8.0, 14.0, 15.0, 16.0]) + + def testMaxPoolSamePaddingFilter4(self): + expected_output = [ + 21.0, 22.0, 23.0, 24.0, 29.0, 30.0, 31.0, 32.0, 53.0, 54.0, 55.0, 56.0, + 61.0, 62.0, 63.0, 64.0 + ] + self._VerifyValues( + nn_ops.max_pool, + input_sizes=[1, 4, 4, 4], + ksize=[1, 2, 2, 1], + strides=[1, 2, 2, 1], + padding="SAME", + expected=expected_output) + + def testMaxPoolSamePaddingFilter8(self): + expected_output = [ + 145.0, 146.0, 147.0, 148.0, 149.0, 150.0, 151.0, 152.0, 161.0, 162.0, + 163.0, 164.0, 165.0, 166.0, 167.0, 168.0, 177.0, 178.0, 179.0, 180.0, + 181.0, 182.0, 183.0, 184.0, 185.0, 186.0, 187.0, 188.0, 189.0, 190.0, + 191.0, 192.0, 273.0, 274.0, 275.0, 276.0, 277.0, 278.0, 279.0, 280.0, + 289.0, 290.0, 291.0, 292.0, 293.0, 294.0, 295.0, 296.0, 305.0, 306.0, + 307.0, 308.0, 309.0, 310.0, 311.0, 312.0, 313.0, 314.0, 315.0, 316.0, + 317.0, 318.0, 319.0, 320.0, 401.0, 402.0, 403.0, 404.0, 405.0, 406.0, + 407.0, 408.0, 417.0, 418.0, 419.0, 420.0, 421.0, 422.0, 423.0, 424.0, + 433.0, 434.0, 435.0, 436.0, 437.0, 438.0, 439.0, 440.0, 441.0, 442.0, + 443.0, 444.0, 445.0, 446.0, 447.0, 448.0, 465.0, 466.0, 467.0, 468.0, + 469.0, 470.0, 471.0, 472.0, 481.0, 482.0, 483.0, 484.0, 485.0, 486.0, + 487.0, 488.0, 497.0, 498.0, 499.0, 500.0, 501.0, 502.0, 503.0, 504.0, + 505.0, 506.0, 507.0, 508.0, 509.0, 510.0, 511.0, 512.0 + ] + self._VerifyValues( + nn_ops.max_pool, + input_sizes=[1, 8, 8, 8], + ksize=[1, 3, 3, 1], + strides=[1, 2, 2, 1], + padding="SAME", + expected=expected_output) + + # Tests for DepthwiseMaxPooling on CPU only. + def testDepthwiseMaxPool1x1DepthWindow1(self): + # input is: + # [1.0, ..., 10.0] along depth, + # + # We maxpool by depth in patches of 2. + self._VerifyValues( + nn_ops.max_pool, + input_sizes=[1, 1, 1, 10], + ksize=[1, 1, 1, 2], + strides=[1, 1, 1, 2], + padding="SAME", + expected=[2.0, 4.0, 6.0, 8.0, 10.0]) + + def testDepthwiseMaxPool2x2DepthWindow3(self): + # input is: + # + # a 2x2x6 cube, and we depthwise max across 3 to produce a 2x2x2 + # output. Each node has contiguous values, so the depthwise max + # should be multiples of 3.0. 
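+    # For example, the first spatial position holds depth values [1..6], so
+    # depth windows of width 3 produce maxima 3 and 6.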
+ self._VerifyValues( + nn_ops.max_pool, + input_sizes=[1, 2, 2, 6], + ksize=[1, 1, 1, 3], + strides=[1, 1, 1, 3], + padding="SAME", + expected=[3.0, 6.0, 9.0, 12.0, 15.0, 18.0, 21.0, 24.0]) + + def testKernelSmallerThanStrideValid(self): + self._VerifyValues( + nn_ops.max_pool, + input_sizes=[1, 7, 7, 1], + ksize=[1, 2, 2, 1], + strides=[1, 3, 3, 1], + padding="VALID", + expected=[9, 12, 30, 33]) + + def testKernelSmallerThanStrideSame(self): + self._VerifyValues( + nn_ops.max_pool, + input_sizes=[1, 3, 3, 1], + ksize=[1, 1, 1, 1], + strides=[1, 2, 2, 1], + padding="SAME", + expected=[1, 3, 7, 9]) + + self._VerifyValues( + nn_ops.max_pool, + input_sizes=[1, 4, 4, 1], + ksize=[1, 1, 1, 1], + strides=[1, 2, 2, 1], + padding="SAME", + expected=[1, 3, 9, 11]) + + # Average pooling + def testAvgPoolValidPadding(self): + expected_output = [7, 8, 9] + self._VerifyValues( + nn_ops.avg_pool, + input_sizes=[1, 3, 3, 3], + ksize=[1, 2, 2, 1], + strides=[1, 2, 2, 1], + padding="VALID", + expected=expected_output) + + def testAvgPoolSamePadding(self): + expected_output = [7., 8., 9., 11.5, 12.5, 13.5] + self._VerifyValues( + nn_ops.avg_pool, + input_sizes=[1, 2, 3, 3], + ksize=[1, 2, 2, 1], + strides=[1, 2, 2, 1], + padding="SAME", + expected=expected_output) + + +class PoolGradTest(XLATestCase): + + CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0" + + def _VerifyOneTest(self, pool_func, pool_grad_func, input_sizes, ksize, + strides, padding, data_format): + """Verifies the output values of the pooling gradient function. + + Args: + pool_func: Forward pooling function + pool_grad_func: Pooling gradient function for pool_grad_func + input_sizes: Input tensor dimensions. + ksize: The kernel size dimensions + strides: The stride dimensions + padding: Padding type. + data_format: The data format we use to run the pooling operation. + """ + total_size = np.prod(input_sizes) + x = np.arange(1, total_size + 1, dtype=np.float32).reshape(input_sizes) + with self.test_session() as sess: + # Use the forward pool function to compute some corresponding outputs + # (needed for the CPU device, and we need the shape in both cases). + with ops.device(self.CPU_DEVICE): + inputs = array_ops.placeholder(dtypes.float32, shape=input_sizes) + outputs = pool_func( + inputs, + ksize=ksize, + strides=strides, + padding=padding, + data_format="NHWC") + + output_vals = np.array(sess.run(outputs, {inputs: x})) + output_gradient_vals = np.arange( + 1, output_vals.size + 1, dtype=np.float32) + output_gradient_vals = output_gradient_vals.reshape(output_vals.shape) + + # Use the Tensorflow CPU pooling gradient to compute the expected input + # gradients. 
+ with ops.device(self.CPU_DEVICE): + output_gradients = array_ops.placeholder( + dtypes.float32, shape=output_vals.shape) + expected_input_gradients = pool_grad_func( + inputs, + outputs, + output_gradients, + ksize=ksize, + strides=strides, + padding=padding, + data_format="NHWC") + expected_input_gradient_vals = sess.run( + expected_input_gradients, + {inputs: x, + output_gradients: output_gradient_vals}) + + # Run the gradient op on the XLA device + with self.test_scope(): + outputs = array_ops.placeholder(dtypes.float32, shape=output_vals.shape) + xla_inputs = inputs + xla_outputs = outputs + xla_output_gradients = output_gradients + xla_ksize = ksize + xla_strides = strides + if data_format == "NCHW": + xla_inputs = NHWCToNCHW(inputs) + xla_outputs = NHWCToNCHW(outputs) + xla_output_gradients = NHWCToNCHW(output_gradients) + xla_ksize = NHWCToNCHW(ksize) + xla_strides = NHWCToNCHW(strides) + actual_input_gradients = pool_grad_func( + xla_inputs, + xla_outputs, + xla_output_gradients, + ksize=xla_ksize, + strides=xla_strides, + padding=padding, + data_format=data_format) + if data_format == "NCHW": + actual_input_gradients = NCHWToNHWC(actual_input_gradients) + actual = sess.run(actual_input_gradients, { + inputs: x, + outputs: output_vals, + output_gradients: output_gradient_vals + }) + + # Compare the Tensorflow and XLA results. + self.assertAllClose( + expected_input_gradient_vals.flatten(), + actual.flatten(), + rtol=1e-5, + atol=1e-6) + self.assertShapeEqual(actual, inputs) + + def _VerifyValues(self, pool_func, pool_grad_func, input_sizes, ksize, + strides, padding): + """Verifies the output values of the pooling function. + + Args: + pool_func: Pooling function to be called, e.g., tf.nn.max_pool + pool_grad_func: Corresponding pooling gradient function. + input_sizes: Input tensor dimensions. + ksize: The kernel size dimensions + strides: The stride dimensions + padding: Padding type. + """ + for data_format in GetTestConfigs(): + self._VerifyOneTest(pool_func, pool_grad_func, input_sizes, ksize, + strides, padding, data_format) + + def _TestPooling(self, forward_op, backward_op): + # VALID padding + self._VerifyValues( + forward_op, + backward_op, + input_sizes=[1, 3, 3, 3], + ksize=[1, 2, 2, 1], + strides=[1, 2, 2, 1], + padding="VALID") + + # SAME padding + self._VerifyValues( + forward_op, + backward_op, + input_sizes=[1, 2, 3, 3], + ksize=[1, 2, 2, 1], + strides=[1, 2, 2, 1], + padding="SAME") + + # SAME padding, non square window + self._VerifyValues( + forward_op, + backward_op, + input_sizes=[1, 2, 2, 1], + ksize=[1, 1, 2, 1], + strides=[1, 1, 1, 1], + padding="SAME") + + # VALID padding, uneven stride + self._VerifyValues( + forward_op, + backward_op, + input_sizes=[1, 4, 4, 1], + ksize=[1, 2, 2, 1], + strides=[1, 1, 2, 1], + padding="VALID") + self._VerifyValues( + forward_op, + backward_op, + input_sizes=[1, 4, 4, 1], + ksize=[1, 2, 2, 1], + strides=[1, 2, 1, 1], + padding="VALID") + + # SAME padding, size 4 input + self._VerifyValues( + forward_op, + backward_op, + input_sizes=[1, 4, 4, 4], + ksize=[1, 2, 2, 1], + strides=[1, 2, 2, 1], + padding="SAME") + + # SAME padding, size 8 input + self._VerifyValues( + forward_op, + backward_op, + input_sizes=[1, 8, 8, 8], + ksize=[1, 3, 3, 1], + strides=[1, 2, 2, 1], + padding="SAME") + + def testMaxPool(self): + self._TestPooling(nn_ops.max_pool, gen_nn_ops._max_pool_grad) + + def testAvgPool(self): + # Wrapper around AvgPoolGrad that ignores extra arguments needed by + # MaxPoolGrad. 
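+    # Average pooling's gradient depends only on the input shape and the
+    # output gradients, not on the forward input or output values, so those
+    # extra tensors can simply be dropped.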
+ def AvgPoolGrad(inputs, outputs, output_gradients, ksize, strides, padding, + data_format): + del outputs # Unused by average-pooling gradients. + return gen_nn_ops._avg_pool_grad( + inputs.get_shape().as_list(), + output_gradients, + ksize=ksize, + strides=strides, + padding=padding, + data_format=data_format) + + self._TestPooling(nn_ops.avg_pool, AvgPoolGrad) + + # The CPU implementation of AvgPoolGrad doesn't accept kernels smaller than + # the stride size, so we only run the following tests on MaxPoolGrad. + + def testMaxPoolKernelSmallerThanStrideValid(self): + self._VerifyValues( + nn_ops.max_pool, + gen_nn_ops._max_pool_grad, + input_sizes=[1, 7, 7, 1], + ksize=[1, 2, 2, 1], + strides=[1, 3, 3, 1], + padding="VALID") + + def testMaxPoolKernelSmallerThanStrideSame(self): + self._VerifyValues( + nn_ops.max_pool, + gen_nn_ops._max_pool_grad, + input_sizes=[1, 3, 3, 1], + ksize=[1, 1, 1, 1], + strides=[1, 2, 2, 1], + padding="SAME") + + self._VerifyValues( + nn_ops.max_pool, + gen_nn_ops._max_pool_grad, + input_sizes=[1, 4, 4, 1], + ksize=[1, 1, 1, 1], + strides=[1, 2, 2, 1], + padding="SAME") + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc new file mode 100644 index 0000000000..41403858a6 --- /dev/null +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -0,0 +1,2097 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Randomized tests for XLA implementations of Tensorflow operations. +// +// For each operator, the tests in this file choose a set of random inputs and +// attributes. The test then compares the outputs of the operator when executed +// via Tensorflow using the CPU device and when executed via XLA. +// +// By default, each test chooses a random seed nondeterministically (using +// std::random_device). However, a particular choice of random seed can be +// forced using the flag --tf_xla_random_seed; each test logs the +// flag value necessary to reproduce its outputs. 
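+// Rerunning with --tf_xla_random_seed set to that logged value reproduces the
+// same randomly generated inputs and attributes.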
+// +// Example usage: +// Run tests, comparing the Tensorflow CPU operators with their XLA-compiled +// counterparts: +// randomized_tests \ +// --tf_xla_test_use_jit=true --tf_xla_test_device=CPU \ +// --tf_xla_test_repetitions=20 + +// TODO(phawkins): add tests for: +// * ArgMax +// * DepthwiseConv2DNative +// * Gather +// * InvertPermutation +// * MaxPoolGrad (requires implementation of forward operator) +// * Select +// * Unpack +// +// TODO(phawkins): improve tests for: +// * StridedSliceGrad (need to use shape function to compute sensible inputs) + +#include +#include + +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace tensorflow { +namespace { + +// Command line flags: see main() below. +int64 tf_xla_random_seed = 0; +int32 tf_xla_test_repetitions = 20; +string* tf_xla_test_device_ptr; // initial value set in main() +bool tf_xla_test_use_jit = true; + +string DeviceTypeToDeviceName(DeviceType type) { + return strings::StrCat("/job:localhost/replica:0/task:0/device:", type.type(), + ":0"); +} + +constexpr std::array kAllXlaTypes = { + {DT_INT32, DT_FLOAT, DT_BOOL}}; + +// An OpTestBuilder is a graph builder class that takes as input an operator to +// test, its inputs and attributes, and builds a graph that executes the +// operator. +class OpTestBuilder { + public: + explicit OpTestBuilder(const string& op_name); + + // Adds an input 'tensor'. + OpTestBuilder& Input(Tensor tensor); + + // Sets an attribute. + template + OpTestBuilder& Attr(StringPiece attr_name, T&& value); + + // Overload needed to allow {...} expressions for value. + template + OpTestBuilder& Attr(StringPiece attr_name, std::initializer_list value); + + // Adds nodes that executes the operator under test on 'device' to 'graphdef'. + // If 'use_jit' is true, marks the operator under test to be compiled by XLA. + // The graph will consist of one Placeholder node per input, the operator + // itself, and one Identity node per output. If 'test_node_def' is not null, + // sets it to the NodeDef of the operator under test. Fills 'inputs' and + // 'outputs' with the names of the input placeholder nodes and the output + // identity nodes, respectively. 
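+  // For example, with name_prefix "test1_expected" a one-input, one-output
+  // operator yields nodes named "test1_expected_input_0",
+  // "test1_expected_op_under_test" and "test1_expected_output_0".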
+ Status BuildGraph(string name_prefix, string device, bool use_jit, + GraphDef* graphdef, NodeDef** test_node_def, + std::vector* inputs, + std::vector* outputs) const; + + const std::vector& inputs() const { return inputs_; } + + private: + NodeDef node_def_; + std::vector inputs_; +}; + +OpTestBuilder::OpTestBuilder(const string& op_name) { + node_def_.set_op(op_name); +} + +OpTestBuilder& OpTestBuilder::Input(Tensor tensor) { + VLOG(1) << "Adding input: " << tensor.DebugString(); + inputs_.push_back(tensor); + return *this; +} + +template +OpTestBuilder& OpTestBuilder::Attr(StringPiece attr_name, T&& value) { + AddNodeAttr(attr_name, std::forward(value), &node_def_); + return *this; +} + +template +OpTestBuilder& OpTestBuilder::Attr(StringPiece attr_name, + std::initializer_list value) { + Attr>(attr_name, std::move(value)); + return *this; +} + +Status OpTestBuilder::BuildGraph(string name_prefix, string device, + bool use_jit, GraphDef* graphdef, + NodeDef** test_node_def, + std::vector* inputs, + std::vector* outputs) const { + OpRegistryInterface* op_registry = OpRegistry::Global(); + + const OpDef* op_def; + TF_RETURN_IF_ERROR(op_registry->LookUpOpDef(node_def_.op(), &op_def)); + + NodeDef* test_def = graphdef->add_node(); + *test_def = node_def_; + test_def->set_name(strings::StrCat(name_prefix, "_op_under_test")); + test_def->set_device(device); + AddDefaultsToNodeDef(*op_def, test_def); + if (use_jit) { + AddNodeAttr(kXlaCompileAttr, true, test_def); + } + VLOG(1) << "Op under test: " << test_def->DebugString(); + + DataTypeVector input_types, output_types; + TF_RETURN_IF_ERROR( + InOutTypesForNode(*test_def, *op_def, &input_types, &output_types)); + + // Build feed and fetch nodes. + for (int i = 0; i < input_types.size(); ++i) { + NodeDef* def = graphdef->add_node(); + string name = strings::StrCat(name_prefix, "_input_", i); + TF_RETURN_IF_ERROR(NodeDefBuilder(name, "Placeholder") + .Device(device) + .Attr("dtype", input_types[i]) + .Finalize(def)); + inputs->push_back(name); + test_def->add_input(name); + } + + for (int i = 0; i < output_types.size(); ++i) { + NodeDef* def = graphdef->add_node(); + string name = strings::StrCat(name_prefix, "_output_", i); + TF_RETURN_IF_ERROR(NodeDefBuilder(name, "Identity") + .Device(device) + .Attr("T", output_types[i]) + .Input(test_def->name(), i, output_types[i]) + .Finalize(def)); + outputs->push_back(name); + } + + if (test_node_def) { + *test_node_def = test_def; + } + + return Status::OK(); +} + +// Test fixture. The fixture manages the random number generator and its seed, +// and has a number of convenience methods for building random Tensors, shapes, +// etc. +class OpTest : public ::testing::Test { + public: + OpTest(); + + // Runs 'fn' up to --tf_xla_test_repetitions times, or until a failure occurs; + // whichever happens first. + void Repeatedly(std::function fn); + + // Select a random element from 'candidates'. + template + T Choose(gtl::ArraySlice candidates); + + static constexpr int kDefaultMaxRank = 5; + static constexpr int64 kDefaultMaxDimensionSize = 20LL; + + // Returns a random dimension size. + int64 RandomDim(int64 min = 0, int64 max = kDefaultMaxDimensionSize); + + // Returns a random shape. The tensor has rank in the range [min_rank, + // max_rank). + // Each dimension has size [0, kDefaultMaxDimensionSize]. 
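+  // For example, RandomDims(1, 3, 1, 10) can yield shapes such as {4} or
+  // {2, 7}.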
+ std::vector RandomDims(int min_rank = 0, + int max_rank = kDefaultMaxRank, + int64 min_size = 0, + int64 max_size = kDefaultMaxDimensionSize); + + // Given a shape 'dims', build a pair of dimensions such that one broadcasts + // to the other. + std::pair, std::vector> BroadcastableDims( + std::vector dims); + + // Builds a random pair of broadcastable dims. + // TODO(phawkins): currently the maximum rank is 3, because broadcasting > 3 + // dimensions is unimplemented by the Tensorflow Eigen code (b/29268487) + std::pair, std::vector> BroadcastableDims(); + + // Returns a tensor filled with random but "reasonable" values from the middle + // of the type's range. If the shape is omitted, a random shape is used. + // TODO(phawkins): generalize this code to a caller-supplied distribution. + Tensor RandomTensor(DataType dtype, gtl::ArraySlice shape); + Tensor RandomTensor(DataType dtype); + + // Like RandomTensor, but uses values >= 0. + Tensor RandomNonNegativeTensor(DataType dtype, gtl::ArraySlice shape); + Tensor RandomNonNegativeTensor(DataType dtype); + + // Returns a random subset of the integers in the range [0, rank), suitable + // for use as reduction indices. + Tensor RandomReductionIndices(int rank); + + struct WindowedDims { + Padding padding; + int kernel_rows, kernel_cols; + int stride_rows, stride_cols; + int input_rows, input_cols; + int64 output_rows, output_cols; + }; + // Choose dimensions for a 2D windowed op such as pooling or convolution. + // TODO(phawkins): currently this only produces spatial windows, in NHWC + // format. + WindowedDims ChooseWindowedDims(); + + std::mt19937& generator() { return *generator_; } + + // Run the test case described by 'builder' with and without XLA and check + // that the outputs are close. Tensors x and y are close if they have the same + // type, same shape, and have close values. For floating-point tensors, the + // element-wise difference between x and y must no more than + // atol + rtol * abs(x); or both elements may be NaN or infinity. For + // non-floating-point tensors the element values must match exactly. + void ExpectTfAndXlaOutputsAreClose(const OpTestBuilder& builder, + double atol = 1e-2, double rtol = 1e-2); + + protected: + // Per-test state: + std::unique_ptr generator_; + + std::unique_ptr session_; + + // Number of test cases built in 'session_'. Used to uniquify node names. + int num_tests_ = 0; +}; + +OpTest::OpTest() { + // Creates a random-number generator for the test case. Use the value of + // --tf_xla_random_seed as the seed, if provided. + int64 s = tf_xla_random_seed; + unsigned int seed; + if (s <= 0) { + std::random_device random_device; + seed = random_device(); + } else { + seed = static_cast(s); + } + LOG(INFO) << "Random seed for test case: " << seed + << ". To reproduce the " + "results of this test, pass flag --tf_xla_random_seed=" + << seed; + generator_.reset(new std::mt19937(seed)); + + // Create a session with an empty graph. 
+ SessionOptions session_options; + session_.reset(NewSession(session_options)); + GraphDef def; + TF_CHECK_OK(session_->Create(def)); +} + +void OpTest::Repeatedly(std::function fn) { + int const max_repetitions = tf_xla_test_repetitions; + for (int i = 0; !HasFailure() && i < max_repetitions; ++i) { + fn(); + } +} + +template +T OpTest::Choose(gtl::ArraySlice candidates) { + std::uniform_int_distribution d(0, candidates.size() - 1); + return candidates[d(generator())]; +} + +int64 OpTest::RandomDim(int64 min, int64 max) { + std::uniform_int_distribution size_distribution(min, max - 1); + return size_distribution(generator()); +} + +std::vector OpTest::RandomDims(int min_rank, int max_rank, + int64 min_size, int64 max_size) { + CHECK_LE(0, min_rank); + CHECK_LE(min_rank, max_rank); + std::uniform_int_distribution rank_distribution(min_rank, max_rank); + int rank = rank_distribution(generator()); + std::vector dims(rank); + std::generate(dims.begin(), dims.end(), [this, min_size, max_size]() { + return RandomDim(min_size, max_size); + }); + return dims; +} + +Tensor OpTest::RandomTensor(DataType dtype, gtl::ArraySlice shape) { + Tensor tensor(dtype, TensorShape(shape)); + switch (dtype) { + case DT_FLOAT: { + std::uniform_real_distribution distribution(-1.0f, 1.0f); + test::FillFn(&tensor, [this, &distribution](int i) -> float { + return distribution(generator()); + }); + break; + } + case DT_DOUBLE: { + std::uniform_real_distribution distribution(-1.0, 1.0); + test::FillFn(&tensor, [this, &distribution](int i) -> double { + return distribution(generator()); + }); + break; + } + case DT_INT32: { + std::uniform_int_distribution distribution(-(1 << 20), 1 << 20); + test::FillFn(&tensor, [this, &distribution](int i) -> int32 { + return distribution(generator()); + }); + break; + } + case DT_INT64: { + std::uniform_int_distribution distribution(-(1LL << 40), + 1LL << 40); + test::FillFn(&tensor, [this, &distribution](int i) -> int64 { + return distribution(generator()); + }); + break; + } + case DT_BOOL: { + std::bernoulli_distribution distribution; + test::FillFn(&tensor, [this, &distribution](int i) -> bool { + return distribution(generator()); + }); + break; + } + default: + LOG(FATAL) << "Unimplemented type " << dtype << " in RandomTensor"; + } + return tensor; +} + +Tensor OpTest::RandomTensor(DataType dtype) { + return RandomTensor(dtype, RandomDims()); +} + +Tensor OpTest::RandomNonNegativeTensor(DataType dtype, + gtl::ArraySlice shape) { + Tensor tensor(dtype, TensorShape(shape)); + switch (dtype) { + case DT_FLOAT: { + std::uniform_real_distribution distribution(0.0f, 1.0f); + test::FillFn(&tensor, [this, &distribution](int i) -> float { + return distribution(generator()); + }); + break; + } + case DT_DOUBLE: { + std::uniform_real_distribution distribution(0.0, 1.0); + test::FillFn(&tensor, [this, &distribution](int i) -> double { + return distribution(generator()); + }); + break; + } + case DT_INT32: { + std::uniform_int_distribution distribution(0, 1 << 20); + test::FillFn(&tensor, [this, &distribution](int i) -> int32 { + return distribution(generator()); + }); + break; + } + case DT_INT64: { + std::uniform_int_distribution distribution(0, 1LL << 40); + test::FillFn(&tensor, [this, &distribution](int i) -> int64 { + return distribution(generator()); + }); + break; + } + default: + LOG(FATAL) << "Unimplemented type " << dtype + << " in RandomNonNegativeTensor"; + } + return tensor; +} + +Tensor OpTest::RandomNonNegativeTensor(DataType dtype) { + return RandomNonNegativeTensor(dtype, 
RandomDims()); +} + +std::pair, std::vector> OpTest::BroadcastableDims( + std::vector dims) { + if (dims.empty()) return {dims, dims}; + + // Remove some dimensions from the front of 'dims'. + size_t skip = + std::uniform_int_distribution(0, dims.size() - 1)(generator()); + + std::vector bdims(dims.begin() + skip, dims.end()); + + // Randomly replace some of the remaining dimensions of 'dims' with 1. + std::bernoulli_distribution random_bool; + + for (int64& dim : bdims) { + if (random_bool(generator())) { + dim = 1LL; + } + } + + // Possibly swap the roles of 'dims' and 'bdims'. + if (random_bool(generator())) { + dims.swap(bdims); + } + return {dims, bdims}; +} + +std::pair, std::vector> OpTest::BroadcastableDims() { + return BroadcastableDims(RandomDims(0, 3)); +} + +Tensor OpTest::RandomReductionIndices(int rank) { + std::bernoulli_distribution random_bool; + std::vector indices; + for (int i = 0; i < rank; ++i) { + if (random_bool(generator())) { + indices.push_back(i); + } + } + return test::AsTensor(indices); +} + +OpTest::WindowedDims OpTest::ChooseWindowedDims() { + WindowedDims d; + d.padding = Choose({SAME, VALID}); + std::uniform_int_distribution random_int(1, 5); + Status s; + // Repeatedly try different filter/stride sizes until we find a valid + // combination. + do { + // CPU implementations require stride <= kernel size. + d.kernel_rows = random_int(generator()), + d.input_rows = RandomDim(d.kernel_rows); + d.stride_rows = + std::uniform_int_distribution(1, d.kernel_rows)(generator()); + int64 pad_dummy; + s = GetWindowedOutputSize(d.input_rows, d.kernel_rows, d.stride_rows, + d.padding, &d.output_rows, &pad_dummy); + } while (!s.ok()); + do { + d.kernel_cols = random_int(generator()); + d.input_cols = RandomDim(d.kernel_cols); + d.stride_cols = + std::uniform_int_distribution(1, d.kernel_cols)(generator()); + int64 pad_dummy; + s = GetWindowedOutputSize(d.input_cols, d.kernel_cols, d.stride_cols, + d.padding, &d.output_cols, &pad_dummy); + } while (!s.ok()); + return d; +} + +// Functions for comparing tensors. + +template +bool IsClose(const T& x, const T& y, double atol, double rtol) { + if (std::isnan(x) && std::isnan(y)) return true; + if (x == y) return true; // Allow inf == inf. + return fabs(x - y) < atol + rtol * fabs(x); +} + +template +Status TensorsAreCloseImpl(const Tensor& x, const Tensor& y, double atol, + double rtol) { + auto Tx = x.flat(); + auto Ty = y.flat(); + for (int i = 0; i < Tx.size(); ++i) { + if (!IsClose(Tx(i), Ty(i), atol, rtol)) { + return errors::InvalidArgument(strings::StrCat( + i, "-th tensor element isn't close: ", Tx(i), " vs. ", Ty(i), + ". x = ", x.DebugString(), "y = ", y.DebugString(), "atol = ", atol, + " rtol = ", rtol, " tol = ", atol + rtol * std::fabs(Tx(i)))); + } + } + return Status::OK(); +} + +template +Status TensorsAreEqualImpl(const Tensor& x, const Tensor& y) { + auto Tx = x.flat(); + auto Ty = y.flat(); + for (int i = 0; i < Tx.size(); ++i) { + if (Tx(i) != Ty(i)) { + return errors::InvalidArgument(strings::StrCat( + i, "-th tensor element isn't equal: ", Tx(i), " vs. ", Ty(i), + ". x = ", x.DebugString(), "y = ", y.DebugString())); + } + } + return Status::OK(); +} + +// Tests if "x" and "y" are tensors of the same type, same shape, and with +// close values. For floating-point tensors, the element-wise difference between +// x and y must no more than atol + rtol * abs(x). For non-floating-point +// tensors the values must match exactly. 
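+// For example, with atol = rtol = 1e-2 an expected element of 100.0 may differ
+// from the test result by anything under 1.01 (1e-2 + 1e-2 * 100.0) and still
+// be accepted.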
+Status TensorsAreClose(const Tensor& a, const Tensor& b, double atol, + double rtol) { + if (a.dtype() != b.dtype()) { + return errors::InvalidArgument(strings::StrCat( + "Tensors have different types: ", DataTypeString(a.dtype()), " and ", + DataTypeString(b.dtype()))); + } + if (!a.IsSameSize(b)) { + return errors::InvalidArgument(strings::StrCat( + "Tensors have different shapes: ", a.shape().DebugString(), " and ", + b.shape().DebugString())); + } + + switch (a.dtype()) { + case DT_FLOAT: + return TensorsAreCloseImpl(a, b, atol, rtol); + case DT_DOUBLE: + return TensorsAreCloseImpl(a, b, atol, rtol); + case DT_INT32: + return TensorsAreEqualImpl(a, b); + case DT_INT64: + return TensorsAreEqualImpl(a, b); + case DT_BOOL: + return TensorsAreEqualImpl(a, b); + default: + LOG(FATAL) << "Unexpected type : " << DataTypeString(a.dtype()); + } +} + +void OpTest::ExpectTfAndXlaOutputsAreClose(const OpTestBuilder& builder, + double atol, double rtol) { + string cpu_device = DeviceTypeToDeviceName(DEVICE_CPU); + DeviceType test_device_type(*tf_xla_test_device_ptr); + string test_device = DeviceTypeToDeviceName(test_device_type); + ++num_tests_; + + GraphDef graph; + std::vector expected_inputs, test_inputs; + std::vector expected_fetches, test_fetches; + TF_ASSERT_OK(builder.BuildGraph( + strings::StrCat("test", num_tests_, "_expected"), cpu_device, + /* use_jit= */ false, &graph, /* test_node_def= */ nullptr, + &expected_inputs, &expected_fetches)); + + NodeDef* node_def; + TF_ASSERT_OK(builder.BuildGraph(strings::StrCat("test", num_tests_, "_test"), + test_device, tf_xla_test_use_jit, &graph, + &node_def, &test_inputs, &test_fetches)); + + // Check that there's a kernel corresponding to 'node_def' on the device under + // test. + Status status = FindKernelDef(test_device_type, *node_def, nullptr, nullptr); + if (!status.ok()) { + VLOG(1) << "Skipping test because there is no corresponding registered " + << "kernel on the test device: " << status; + return; + } + + TF_ASSERT_OK(session_->Extend(graph)); + + const std::vector& input_tensors = builder.inputs(); + if (VLOG_IS_ON(1)) { + for (const Tensor& input : input_tensors) { + VLOG(1) << "Input: " << input.DebugString(); + } + } + + std::vector> expected_feeds(expected_inputs.size()); + std::vector> test_feeds(test_inputs.size()); + ASSERT_EQ(input_tensors.size(), expected_inputs.size()); + ASSERT_EQ(input_tensors.size(), test_inputs.size()); + + for (int i = 0; i < input_tensors.size(); ++i) { + expected_feeds[i] = {expected_inputs[i], input_tensors[i]}; + test_feeds[i] = {test_inputs[i], input_tensors[i]}; + } + + std::vector expected_outputs, test_outputs; + VLOG(1) << "Running expected graph"; + Status s = + session_->Run(expected_feeds, expected_fetches, {}, &expected_outputs); + if (!s.ok()) { + VLOG(1) << "Expected graph failed with status: " << s << ". Skipping test"; + return; + } + + VLOG(1) << "Running test graph"; + TF_ASSERT_OK(session_->Run(test_feeds, test_fetches, {}, &test_outputs)); + + ASSERT_EQ(expected_outputs.size(), test_outputs.size()); + for (int j = 0; s.ok() && j < test_outputs.size(); ++j) { + s = TensorsAreClose(expected_outputs[j], test_outputs[j], atol, rtol); + } + TF_EXPECT_OK(s); +} + +// Helper that converts 'values' to an int32 or int64 Tensor. 
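+// For DT_INT32 the int64 values are narrowed, so callers are expected to pass
+// values that fit in 32 bits.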
+Tensor AsIntTensor(DataType dtype, const std::vector& values) { + switch (dtype) { + case DT_INT32: { + std::vector values32(values.begin(), values.end()); + return test::AsTensor(values32); + } + case DT_INT64: + return test::AsTensor(values); + default: + CHECK(false); + } +} + +TEST_F(OpTest, Abs) { + Repeatedly([this]() { + DataType type = Choose({DT_INT32, DT_FLOAT}); + ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Abs").Input(RandomTensor(type)).Attr("T", type)); + }); +} + +TEST_F(OpTest, Add) { + Repeatedly([this]() { + DataType type = Choose({DT_INT32, DT_FLOAT}); + auto dims = BroadcastableDims(); + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Add") + .Input(RandomTensor(type, dims.first)) + .Input(RandomTensor(type, dims.second)) + .Attr("T", type)); + }); +} + +TEST_F(OpTest, AddN) { + Repeatedly([this]() { + DataType type = Choose({DT_INT32, DT_FLOAT}); + int n = std::uniform_int_distribution(1, 5)(generator()); + + auto shape = RandomDims(); + + OpTestBuilder builder("AddN"); + builder.Attr("T", type); + builder.Attr("N", n); + for (int i = 0; i < n; ++i) { + builder.Input(RandomTensor(type, shape)); + } + ExpectTfAndXlaOutputsAreClose(builder); + }); +} + +TEST_F(OpTest, All) { + Repeatedly([this]() { + Tensor data = RandomTensor(DT_BOOL); + Tensor indices = RandomReductionIndices(data.dims()); + bool keep_dims = Choose({false, true}); + ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("All").Input(data).Input(indices).Attr("keep_dims", + keep_dims)); + }); +} + +TEST_F(OpTest, Any) { + Repeatedly([this]() { + Tensor data = RandomTensor(DT_BOOL); + Tensor indices = RandomReductionIndices(data.dims()); + bool keep_dims = Choose({false, true}); + ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Any").Input(data).Input(indices).Attr("keep_dims", + keep_dims)); + }); +} + +TEST_F(OpTest, AvgPool) { + Repeatedly([this]() { + std::uniform_int_distribution random_int(1, 5); + int kernel_rows = random_int(generator()), + kernel_cols = random_int(generator()); + int stride_rows = random_int(generator()), + stride_cols = random_int(generator()); + string padding = Choose({"SAME", "VALID"}); + ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("AvgPool") + .Input( + RandomTensor(DT_FLOAT, {RandomDim(1), RandomDim(kernel_rows), + RandomDim(kernel_cols), RandomDim(1)})) + .Attr("T", DT_FLOAT) + .Attr("ksize", {1, kernel_rows, kernel_cols, 1}) + .Attr("strides", {1, stride_rows, stride_cols, 1}) + .Attr("padding", padding) + .Attr("data_format", "NHWC")); + }); + // TODO(phawkins): the CPU device only implements spatial pooling. Add tests + // for batch pooling when supported. +} + +TEST_F(OpTest, AvgPoolGrad) { + Repeatedly([this]() { + int batch = RandomDim(1), features = RandomDim(1); + WindowedDims d = ChooseWindowedDims(); + ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("AvgPoolGrad") + .Input(test::AsTensor( + {batch, d.input_rows, d.input_cols, features})) + .Input(RandomTensor( + DT_FLOAT, {batch, d.output_rows, d.output_cols, features})) + .Attr("T", DT_FLOAT) + .Attr("ksize", {1, d.kernel_rows, d.kernel_cols, 1}) + .Attr("strides", {1, d.stride_rows, d.stride_cols, 1}) + .Attr("padding", d.padding == SAME ? 
"SAME" : "VALID") + .Attr("data_format", "NHWC")); + }); +} + +TEST_F(OpTest, BatchMatMul) { + Repeatedly([this]() { + std::vector output_dims = RandomDims(2, 5, 0, 7); + int64 ndims = output_dims.size(); + int64 inner_dim = RandomDim(); + std::vector x_dims(output_dims), y_dims(output_dims); + x_dims[ndims - 1] = inner_dim; + y_dims[ndims - 2] = inner_dim; + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BatchMatMul") + .Input(RandomTensor(DT_FLOAT, x_dims)) + .Input(RandomTensor(DT_FLOAT, y_dims)) + .Attr("T", DT_FLOAT)); + + std::swap(x_dims[ndims - 1], x_dims[ndims - 2]); + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BatchMatMul") + .Input(RandomTensor(DT_FLOAT, x_dims)) + .Input(RandomTensor(DT_FLOAT, y_dims)) + .Attr("T", DT_FLOAT) + .Attr("adj_x", true)); + + std::swap(y_dims[ndims - 1], y_dims[ndims - 2]); + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BatchMatMul") + .Input(RandomTensor(DT_FLOAT, x_dims)) + .Input(RandomTensor(DT_FLOAT, y_dims)) + .Attr("T", DT_FLOAT) + .Attr("adj_x", true) + .Attr("adj_y", true)); + + std::swap(x_dims[ndims - 1], x_dims[ndims - 2]); + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BatchMatMul") + .Input(RandomTensor(DT_FLOAT, x_dims)) + .Input(RandomTensor(DT_FLOAT, y_dims)) + .Attr("T", DT_FLOAT) + .Attr("adj_y", true)); + }); +} + +TEST_F(OpTest, BiasAdd) { + Repeatedly([this]() { + auto x = RandomTensor(DT_FLOAT, RandomDims(2, kDefaultMaxRank)); + auto y = RandomTensor(DT_FLOAT, {x.dim_size(x.dims() - 1)}); + // TODO(phawkins): test both data formats. + ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("BiasAdd").Input(x).Input(y).Attr("T", DT_FLOAT)); + }); +} + +TEST_F(OpTest, BiasAddGrad) { + Repeatedly([this]() { + auto x = RandomTensor(DT_FLOAT); + // TODO(phawkins): test both data formats. + ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("BiasAddGrad").Input(x).Attr("T", DT_FLOAT)); + }); +} + +TEST_F(OpTest, BiasAddV1) { + Repeatedly([this]() { + auto x = RandomTensor(DT_FLOAT, RandomDims(2, kDefaultMaxRank)); + auto y = RandomTensor(DT_FLOAT, {x.dim_size(x.dims() - 1)}); + ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("BiasAddV1").Input(x).Input(y).Attr("T", DT_FLOAT)); + }); +} + +TEST_F(OpTest, BroadcastGradientArgs) { + Repeatedly([this]() { + // TODO(phawkins): only int32 seems to be implemented in Tensorflow. 
+ // DataType type = Choose({DT_INT32, DT_INT64}); + DataType type = DT_INT32; + auto dims = BroadcastableDims(); + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BroadcastGradientArgs") + .Input(AsIntTensor(type, dims.first)) + .Input(AsIntTensor(type, dims.second)) + .Attr("T", type)); + }); +} + +TEST_F(OpTest, Cast) { + Repeatedly([this]() { + DataType src_type, dst_type; + src_type = Choose({DT_INT32, DT_FLOAT, DT_BOOL}); + dst_type = Choose({DT_INT32, DT_FLOAT, DT_BOOL}); + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Cast") + .Input(RandomTensor(src_type)) + .Attr("SrcT", src_type) + .Attr("DstT", dst_type)); + }); +} + +TEST_F(OpTest, Ceil) { + Repeatedly([this]() { + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Ceil") + .Input(RandomTensor(DT_FLOAT)) + .Attr("T", DT_FLOAT)); + }); +} + +TEST_F(OpTest, Concat) { + Repeatedly([this]() { + DataType type = Choose(kAllXlaTypes); + int n = std::uniform_int_distribution(2, 5)(generator()); + + std::vector dims = RandomDims(1); + int concat_dim = + std::uniform_int_distribution(0, dims.size() - 1)(generator()); + + OpTestBuilder builder("Concat"); + builder.Input(test::AsScalar(concat_dim)); + builder.Attr("T", type); + builder.Attr("N", n); + for (int i = 0; i < n; ++i) { + std::vector shape = dims; + shape[concat_dim] = RandomDim(); + builder.Input(RandomTensor(type, shape)); + } + ExpectTfAndXlaOutputsAreClose(builder); + }); +} + +TEST_F(OpTest, ConcatOffset) { + Repeatedly([this]() { + int n = std::uniform_int_distribution(2, 5)(generator()); + + std::vector dims = RandomDims(1); + int concat_dim = + std::uniform_int_distribution(0, dims.size() - 1)(generator()); + + OpTestBuilder builder("ConcatOffset"); + builder.Input(test::AsScalar(concat_dim)); + builder.Attr("N", n); + for (int i = 0; i < n; ++i) { + std::vector shape(dims.begin(), dims.end()); + shape[concat_dim] = RandomDim(); + builder.Input(test::AsTensor(shape)); + } + ExpectTfAndXlaOutputsAreClose(builder); + }); +} + +TEST_F(OpTest, Conv2D) { + Repeatedly([this]() { + WindowedDims d = ChooseWindowedDims(); + std::uniform_int_distribution random_int(1, 5); + int features_in = random_int(generator()); + int features_out = random_int(generator()); + Tensor data = RandomTensor( + DT_FLOAT, {RandomDim(), d.input_rows, d.input_cols, features_in}); + + Tensor kernel = RandomTensor( + DT_FLOAT, {d.kernel_rows, d.kernel_cols, features_in, features_out}); + ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Conv2D") + .Input(data) + .Input(kernel) + .Attr("T", DT_FLOAT) + .Attr("strides", {1, d.stride_rows, d.stride_cols, 1}) + .Attr("padding", d.padding == SAME ? "SAME" : "VALID") + .Attr("data_format", "NHWC")); + }); +} + +TEST_F(OpTest, Conv2DBackpropFilter) { + Repeatedly([this]() { + WindowedDims d = ChooseWindowedDims(); + std::uniform_int_distribution random_int(1, 5); + int features_in = random_int(generator()); + int features_out = random_int(generator()); + int32 batch = RandomDim(); + Tensor activations = RandomTensor( + DT_FLOAT, {batch, d.input_rows, d.input_cols, features_in}); + Tensor backprop = RandomTensor( + DT_FLOAT, {batch, d.output_rows, d.output_cols, features_out}); + Tensor kernel_shape = test::AsTensor( + {d.kernel_rows, d.kernel_cols, features_in, features_out}); + ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Conv2DBackpropFilter") + .Input(activations) + .Input(kernel_shape) + .Input(backprop) + .Attr("T", DT_FLOAT) + .Attr("strides", {1, d.stride_rows, d.stride_cols, 1}) + .Attr("padding", d.padding == SAME ? 
"SAME" : "VALID") + .Attr("data_format", "NHWC")); + }); +} + +TEST_F(OpTest, Conv2DBackpropInput) { + Repeatedly([this]() { + WindowedDims d = ChooseWindowedDims(); + std::uniform_int_distribution random_int(1, 5); + int features_in = random_int(generator()); + int features_out = random_int(generator()); + int32 batch = RandomDim(); + Tensor in_shape = + test::AsTensor({batch, d.input_rows, d.input_cols, features_in}); + Tensor backprop = RandomTensor( + DT_FLOAT, {batch, d.output_rows, d.output_cols, features_out}); + Tensor kernel = RandomTensor( + DT_FLOAT, {d.kernel_rows, d.kernel_cols, features_in, features_out}); + ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Conv2DBackpropInput") + .Input(in_shape) + .Input(kernel) + .Input(backprop) + .Attr("T", DT_FLOAT) + .Attr("strides", {1, d.stride_rows, d.stride_cols, 1}) + .Attr("padding", d.padding == SAME ? "SAME" : "VALID") + .Attr("data_format", "NHWC")); + }); +} + +TEST_F(OpTest, Diag) { + Repeatedly([this]() { + DataType type = Choose({DT_INT32, DT_FLOAT}); + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Diag") + .Input(RandomTensor(type, RandomDims(1))) + .Attr("T", type)); + }); +} + +TEST_F(OpTest, DiagPart) { + Repeatedly([this]() { + DataType type = Choose({DT_INT32, DT_FLOAT}); + auto dims = RandomDims(1, 3); + // Duplicate the random dims. + std::vector doubled_dims(dims.size() * 2); + std::copy(dims.begin(), dims.end(), doubled_dims.begin()); + std::copy(dims.begin(), dims.end(), doubled_dims.begin() + dims.size()); + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("DiagPart") + .Input(RandomTensor(type, doubled_dims)) + .Attr("T", type)); + }); +} + +TEST_F(OpTest, Div) { + Repeatedly([this]() { + DataType type = Choose({DT_INT32, DT_FLOAT}); + auto dims = BroadcastableDims(); + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Div") + .Input(RandomTensor(type, dims.first)) + .Input(RandomTensor(type, dims.second)) + .Attr("T", type)); + }); +} + +TEST_F(OpTest, DynamicStitch) { + Repeatedly([this]() { + DataType type = Choose(kAllXlaTypes); + int n = std::uniform_int_distribution(2, 5)(generator()); + OpTestBuilder builder("DynamicStitch"); + builder.Attr("T", type); + builder.Attr("N", n); + std::vector> index_dims; + int size = 0; + // TODO(phawkins): the XLA implementation of DynamicStitch does not + // accept an empty set of indices. + do { + size = 0; + index_dims.clear(); + for (int i = 0; i < n; ++i) { + std::vector dims = RandomDims(0, 3, 0, 5); + size += TensorShape(dims).num_elements(); + index_dims.push_back(dims); + } + } while (size == 0); + + // Shuffle the range of indices that cover the output. + // TODO(phawkins): The documentation for DynamicStitch doesn't require that + // the indices cover all positions of the output. The XLA implementation + // does so require. However, the native TF implementation leaves undefined + // values if we don't cover everything, so we can't really test that case + // anyway. 
+ std::vector indices(size); + std::iota(indices.begin(), indices.end(), 0); + std::shuffle(indices.begin(), indices.end(), generator()); + + int pos = 0; + for (int i = 0; i < n; ++i) { + TensorShape shape(index_dims[i]); + Tensor t = test::AsTensor( + gtl::ArraySlice(indices, pos, shape.num_elements()), shape); + builder.Input(t); + pos += t.NumElements(); + } + + std::vector constant_dims = RandomDims(0, 3, 0, 5); + for (int i = 0; i < n; ++i) { + std::vector dims(index_dims[i].begin(), index_dims[i].end()); + std::copy(constant_dims.begin(), constant_dims.end(), + std::back_inserter(dims)); + Tensor t = RandomTensor(type, dims); + builder.Input(t); + } + ExpectTfAndXlaOutputsAreClose(builder); + }); +} + +TEST_F(OpTest, Equal) { + Repeatedly([this]() { + DataType type = Choose({DT_INT32, DT_FLOAT}); + auto dims = BroadcastableDims(); + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Equal") + .Input(RandomTensor(type, dims.first)) + .Input(RandomTensor(type, dims.second)) + .Attr("T", type)); + }); +} + +TEST_F(OpTest, Exp) { + Repeatedly([this]() { + ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Exp").Input(RandomTensor(DT_FLOAT)).Attr("T", DT_FLOAT)); + }); +} + +TEST_F(OpTest, ExpandDims) { + Repeatedly([this]() { + DataType type = Choose(kAllXlaTypes); + Tensor in = RandomTensor(type); + Tensor dim(DT_INT32, TensorShape()); + std::uniform_int_distribution d(-1 - in.dims(), in.dims()); + dim.scalar()() = d(generator()); + ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("ExpandDims").Input(in).Input(dim).Attr("T", type)); + }); +} + +TEST_F(OpTest, Fill) { + Repeatedly([this]() { + DataType type = Choose(kAllXlaTypes); + Tensor scalar = RandomTensor(type, {}); + std::vector dims = RandomDims(); + std::vector shape(dims.begin(), dims.end()); + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Fill") + .Input(test::AsTensor(shape)) + .Input(scalar) + .Attr("T", type)); + }); +} + +TEST_F(OpTest, Floor) { + Repeatedly([this]() { + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Floor") + .Input(RandomTensor(DT_FLOAT)) + .Attr("T", DT_FLOAT)); + }); +} + +TEST_F(OpTest, FloorDiv) { + Repeatedly([this]() { + DataType type = DT_INT32; + auto dims = BroadcastableDims(); + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("FloorDiv") + .Input(RandomTensor(type, dims.first)) + .Input(RandomTensor(type, dims.second)) + .Attr("T", type)); + }); +} + +TEST_F(OpTest, FloorMod) { + Repeatedly([this]() { + DataType type = Choose({DT_INT32, DT_FLOAT}); + auto dims = BroadcastableDims(); + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("FloorMod") + .Input(RandomTensor(type, dims.first)) + .Input(RandomTensor(type, dims.second)) + .Attr("T", type)); + }); +} + +TEST_F(OpTest, Greater) { + Repeatedly([this]() { + DataType type = Choose({DT_INT32, DT_FLOAT}); + auto dims = BroadcastableDims(); + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Greater") + .Input(RandomTensor(type, dims.first)) + .Input(RandomTensor(type, dims.second)) + .Attr("T", type)); + }); +} + +TEST_F(OpTest, GreaterEqual) { + Repeatedly([this]() { + DataType type = Choose({DT_INT32, DT_FLOAT}); + auto dims = BroadcastableDims(); + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("GreaterEqual") + .Input(RandomTensor(type, dims.first)) + .Input(RandomTensor(type, dims.second)) + .Attr("T", type)); + }); +} + +TEST_F(OpTest, Reciprocal) { + Repeatedly([this]() { + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Reciprocal") + .Input(RandomTensor(DT_FLOAT)) + .Attr("T", DT_FLOAT)); + }); +} + +TEST_F(OpTest, L2Loss) { + Repeatedly([this]() { + DataType type = 
Choose({DT_INT32, DT_FLOAT}); + // TODO(b/31644876): scalars currently crash. + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("L2Loss") + .Input(RandomTensor(type, RandomDims(1))) + .Attr("T", type)); + }); +} + +TEST_F(OpTest, Less) { + Repeatedly([this]() { + DataType type = Choose({DT_INT32, DT_FLOAT}); + auto dims = BroadcastableDims(); + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Less") + .Input(RandomTensor(type, dims.first)) + .Input(RandomTensor(type, dims.second)) + .Attr("T", type)); + }); +} + +TEST_F(OpTest, LessEqual) { + Repeatedly([this]() { + DataType type = Choose({DT_INT32, DT_FLOAT}); + auto dims = BroadcastableDims(); + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("LessEqual") + .Input(RandomTensor(type, dims.first)) + .Input(RandomTensor(type, dims.second)) + .Attr("T", type)); + }); +} + +TEST_F(OpTest, LinSpace) { + Repeatedly([this]() { + auto ToScalar = [](DataType type, int x) { + if (type == DT_INT32) return test::AsScalar(x); + return test::AsScalar(x); + }; + std::uniform_int_distribution distribution(-50, 50); + DataType type = Choose({DT_INT32, DT_INT64}); + ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("LinSpace") + .Input(RandomTensor(DT_FLOAT, {})) + .Input(RandomTensor(DT_FLOAT, {})) + .Input(ToScalar(type, distribution(generator()))) + .Attr("T", DT_FLOAT) + .Attr("Tidx", type)); + }); +} + +TEST_F(OpTest, Log) { + Repeatedly([this]() { + ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Log").Input(RandomTensor(DT_FLOAT)).Attr("T", DT_FLOAT)); + }); +} + +TEST_F(OpTest, LogicalAnd) { + Repeatedly([this]() { + auto dims = BroadcastableDims(); + ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("LogicalAnd") + .Input(RandomTensor(DT_BOOL, dims.first)) + .Input(RandomTensor(DT_BOOL, dims.second))); + }); +} + +TEST_F(OpTest, LogicalNot) { + Repeatedly([this]() { + ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("LogicalNot").Input(RandomTensor(DT_BOOL))); + }); +} + +TEST_F(OpTest, LogicalOr) { + Repeatedly([this]() { + auto dims = BroadcastableDims(); + ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("LogicalOr") + .Input(RandomTensor(DT_BOOL, dims.first)) + .Input(RandomTensor(DT_BOOL, dims.second))); + }); +} + +TEST_F(OpTest, LogSoftmax) { + Repeatedly([this]() { + ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("LogSoftmax") + .Input(RandomTensor(DT_FLOAT, RandomDims(2, 2))) + .Attr("T", DT_FLOAT)); + }); +} + +TEST_F(OpTest, LRN) { + Repeatedly([this]() { + Tensor data; + // TODO(b/31362467): Crashes with 0 dims on GPU. Re-enable when fixed. + data = RandomTensor(DT_FLOAT, RandomDims(4, 4, 1, 8)); + // CuDNN requires depth_radius > 0. + std::uniform_int_distribution radius(1, data.dim_size(3)); + std::uniform_real_distribution coeff(0.01, 2.0); + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("LRN") + .Input(data) + .Attr("T", DT_FLOAT) + .Attr("depth_radius", radius(generator())) + .Attr("bias", coeff(generator())) + .Attr("alpha", coeff(generator())) + .Attr("beta", coeff(generator()))); + }); +} + +TEST_F(OpTest, LRNGrad) { + Repeatedly([this]() { + // TODO(b/31362467): Crashes with 0 dims on GPU. Re-enable when fixed. + std::vector dims = RandomDims(4, 4, 1, 8); + Tensor input_grads = RandomTensor(DT_FLOAT, dims); + Tensor input_image = RandomTensor(DT_FLOAT, dims); + Tensor output_image = RandomTensor(DT_FLOAT, dims); + // CuDNN requires depth_radius > 0. 
+ std::uniform_int_distribution radius(1, input_grads.dim_size(3)); + std::uniform_real_distribution coeff(0.0, 2.0); + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("LRNGrad") + .Input(input_grads) + .Input(input_image) + .Input(output_image) + .Attr("T", DT_FLOAT) + .Attr("depth_radius", radius(generator())) + .Attr("bias", coeff(generator())) + .Attr("alpha", coeff(generator())) + .Attr("beta", coeff(generator()))); + }); +} + +TEST_F(OpTest, MatMul) { + Repeatedly([this]() { + int64 x = RandomDim(); + int64 y = RandomDim(); + int64 z = RandomDim(); + + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatMul") + .Input(RandomTensor(DT_FLOAT, {x, y})) + .Input(RandomTensor(DT_FLOAT, {y, z})) + .Attr("T", DT_FLOAT)); + + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatMul") + .Input(RandomTensor(DT_FLOAT, {y, x})) + .Input(RandomTensor(DT_FLOAT, {y, z})) + .Attr("T", DT_FLOAT) + .Attr("transpose_a", true)); + + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatMul") + .Input(RandomTensor(DT_FLOAT, {x, y})) + .Input(RandomTensor(DT_FLOAT, {z, y})) + .Attr("T", DT_FLOAT) + .Attr("transpose_b", true)); + + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatMul") + .Input(RandomTensor(DT_FLOAT, {y, x})) + .Input(RandomTensor(DT_FLOAT, {z, y})) + .Attr("T", DT_FLOAT) + .Attr("transpose_a", true) + .Attr("transpose_b", true)); + }); +} + +TEST_F(OpTest, MatrixDiag) { + Repeatedly([this]() { + DataType type = Choose({DT_BOOL, DT_INT32, DT_FLOAT}); + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatrixDiag") + .Input(RandomTensor(type, RandomDims(1))) + .Attr("T", type)); + }); +} + +TEST_F(OpTest, MatrixDiagPart) { + Repeatedly([this]() { + DataType type = Choose({DT_BOOL, DT_INT32, DT_FLOAT}); + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatrixDiagPart") + .Input(RandomTensor(type, RandomDims(2))) + .Attr("T", type)); + }); +} + +TEST_F(OpTest, Max) { + Repeatedly([this]() { + DataType type = Choose({DT_INT32, DT_FLOAT}); + Tensor data = RandomTensor(type); + Tensor indices = RandomReductionIndices(data.dims()); + bool keep_dims = Choose({false, true}); + ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Max").Input(data).Input(indices).Attr("T", type).Attr( + "keep_dims", keep_dims)); + }); +} + +TEST_F(OpTest, Maximum) { + Repeatedly([this]() { + DataType type = Choose({DT_INT32, DT_FLOAT}); + auto dims = BroadcastableDims(); + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Maximum") + .Input(RandomTensor(type, dims.first)) + .Input(RandomTensor(type, dims.second)) + .Attr("T", type)); + }); +} + +TEST_F(OpTest, MaxPool) { + Repeatedly([this]() { + std::uniform_int_distribution random_int(1, 5); + int kernel_rows = random_int(generator()), + kernel_cols = random_int(generator()); + int stride_rows = random_int(generator()), + stride_cols = random_int(generator()); + string padding = Choose({"SAME", "VALID"}); + ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("MaxPool") + .Input( + RandomTensor(DT_FLOAT, {RandomDim(1), RandomDim(kernel_rows), + RandomDim(kernel_cols), RandomDim(1)})) + .Attr("T", DT_FLOAT) + .Attr("ksize", {1, kernel_rows, kernel_cols, 1}) + .Attr("strides", {1, stride_rows, stride_cols, 1}) + .Attr("padding", padding) + .Attr("data_format", "NHWC")); + }); + // TODO(phawkins): test NCHW format (not supported by CPU) +} + +TEST_F(OpTest, Mean) { + Repeatedly([this]() { + DataType type = Choose({DT_INT32, DT_FLOAT}); + // TODO(phawkins): CPU and XLA differ output for reducing across a + // size-0 dimension (nan vs 0). For now, require size >= 1. 
+ Tensor data = RandomTensor(type, RandomDims(0, kDefaultMaxRank, 1)); + Tensor indices = RandomReductionIndices(data.dims()); + bool keep_dims = Choose({false, true}); + ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Mean").Input(data).Input(indices).Attr("T", type).Attr( + "keep_dims", keep_dims)); + }); +} + +TEST_F(OpTest, Min) { + Repeatedly([this]() { + DataType type = Choose({DT_INT32, DT_FLOAT}); + Tensor data = RandomTensor(type); + Tensor indices = RandomReductionIndices(data.dims()); + bool keep_dims = Choose({false, true}); + ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Min").Input(data).Input(indices).Attr("T", type).Attr( + "keep_dims", keep_dims)); + }); +} + +TEST_F(OpTest, Minimum) { + Repeatedly([this]() { + DataType type = Choose({DT_INT32, DT_FLOAT}); + auto dims = BroadcastableDims(); + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Minimum") + .Input(RandomTensor(type, dims.first)) + .Input(RandomTensor(type, dims.second)) + .Attr("T", type)); + }); +} + +TEST_F(OpTest, Mod) { + Repeatedly([this]() { + auto dims = BroadcastableDims(); + ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Mod") + .Input(RandomTensor(DT_INT32, dims.first)) + .Input(RandomTensor(DT_INT32, dims.second)) + .Attr("T", DT_INT32)); + }); +} + +TEST_F(OpTest, Mul) { + Repeatedly([this]() { + DataType type = Choose({DT_INT32, DT_FLOAT}); + auto dims = BroadcastableDims(); + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Mul") + .Input(RandomTensor(type, dims.first)) + .Input(RandomTensor(type, dims.second)) + .Attr("T", type)); + }); +} + +TEST_F(OpTest, Neg) { + Repeatedly([this]() { + DataType type = Choose({DT_INT32, DT_FLOAT}); + ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Neg").Input(RandomTensor(type)).Attr("T", type)); + }); +} + +TEST_F(OpTest, NotEqual) { + Repeatedly([this]() { + DataType type = Choose({DT_INT32, DT_FLOAT}); + auto dims = BroadcastableDims(); + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("NotEqual") + .Input(RandomTensor(type, dims.first)) + .Input(RandomTensor(type, dims.second)) + .Attr("T", type)); + }); +} + +TEST_F(OpTest, Pack) { + Repeatedly([this]() { + DataType type = Choose(kAllXlaTypes); + int n = std::uniform_int_distribution(1, 5)(generator()); + + std::vector dims = RandomDims(); + int num_dims = dims.size(); + int axis = std::uniform_int_distribution(-num_dims - 1, + num_dims)(generator()); + + OpTestBuilder builder("Pack"); + builder.Attr("T", type); + builder.Attr("N", n); + builder.Attr("axis", axis); + for (int i = 0; i < n; ++i) { + builder.Input(RandomTensor(type, dims)); + } + ExpectTfAndXlaOutputsAreClose(builder); + }); +} + +// TODO(b/31741898): crashes on GPU. +TEST_F(OpTest, Pad) { + Repeatedly([this]() { + DataType type = Choose(kAllXlaTypes); + Tensor t = RandomTensor(type); + + // TODO(b/31741996): re-enable DT_INT64 when bug is fixed. + // DataType tpaddings = Choose({DT_INT32, DT_INT64}); + DataType tpaddings = DT_INT32; + std::vector paddings_vec; + std::uniform_int_distribution distribution(0, 7); + for (int i = 0; i < t.dims(); ++i) { + paddings_vec.push_back(distribution(generator())); + paddings_vec.push_back(distribution(generator())); + } + Tensor paddings; + CHECK(paddings.CopyFrom(AsIntTensor(tpaddings, paddings_vec), + TensorShape({t.dims(), 2}))); + ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Pad").Input(t).Input(paddings).Attr("T", type).Attr( + "Tpaddings", tpaddings)); + }); +} + +TEST_F(OpTest, Pow) { + // TODO(phawkins): Feeding large DT_INT32 values to Pow() leads to + // nontermination. 
+ Repeatedly([this]() { + auto dims = BroadcastableDims(); + ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Pow") + .Input(RandomTensor(DT_FLOAT, dims.first)) + .Input(RandomTensor(DT_FLOAT, dims.second)) + .Attr("T", DT_FLOAT)); + }); +} + +TEST_F(OpTest, Prod) { + Repeatedly([this]() { + DataType type = Choose({DT_INT32, DT_FLOAT}); + Tensor data = RandomTensor(type); + Tensor indices = RandomReductionIndices(data.dims()); + bool keep_dims = Choose({false, true}); + ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Prod").Input(data).Input(indices).Attr("T", type).Attr( + "keep_dims", keep_dims)); + }); +} + +TEST_F(OpTest, Range) { + Repeatedly([this]() { + auto ToScalar = [](DataType type, int x) { + if (type == DT_INT32) return test::AsScalar(x); + if (type == DT_INT64) return test::AsScalar(x); + if (type == DT_FLOAT) return test::AsScalar(x); + if (type == DT_DOUBLE) return test::AsScalar(x); + LOG(FATAL) << "Unknown type " << DataTypeString(type); + }; + std::uniform_int_distribution distribution(-50, 50); + DataType tidx = Choose({DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE}); + ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Range") + .Input(ToScalar(tidx, distribution(generator()))) + .Input(ToScalar(tidx, distribution(generator()))) + .Input(ToScalar(tidx, distribution(generator()))) + .Attr("Tidx", tidx)); + }); +} + +TEST_F(OpTest, Rank) { + Repeatedly([this]() { + DataType type = Choose({DT_INT32, DT_FLOAT}); + ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Rank").Input(RandomTensor(type)).Attr("T", type)); + }); +} + +TEST_F(OpTest, RealDiv) { + Repeatedly([this]() { + DataType type = DT_FLOAT; + auto dims = BroadcastableDims(); + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("RealDiv") + .Input(RandomTensor(type, dims.first)) + .Input(RandomTensor(type, dims.second)) + .Attr("T", type)); + }); +} + +TEST_F(OpTest, Relu) { + Repeatedly([this]() { + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Relu") + .Input(RandomTensor(DT_FLOAT)) + .Attr("T", DT_FLOAT)); + }); +} + +TEST_F(OpTest, Relu6) { + Repeatedly([this]() { + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Relu6") + .Input(RandomTensor(DT_FLOAT)) + .Attr("T", DT_FLOAT)); + }); +} + +TEST_F(OpTest, Relu6Grad) { + Repeatedly([this]() { + auto dims = RandomDims(1); + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Relu6Grad") + .Input(RandomTensor(DT_FLOAT, dims)) + .Input(RandomTensor(DT_FLOAT, dims)) + .Attr("T", DT_FLOAT)); + }); +} + +TEST_F(OpTest, ReluGrad) { + Repeatedly([this]() { + auto dims = RandomDims(1); + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ReluGrad") + .Input(RandomTensor(DT_FLOAT, dims)) + .Input(RandomTensor(DT_FLOAT, dims)) + .Attr("T", DT_FLOAT)); + }); +} + +TEST_F(OpTest, Reshape) { + Repeatedly([this]() { + DataType type = Choose(kAllXlaTypes); + std::vector dims = RandomDims(); + std::bernoulli_distribution random_bool; + std::vector dims_before, dims_after; + for (std::vector* out : {&dims_before, &dims_after}) { + std::shuffle(dims.begin(), dims.end(), generator()); + for (int64 dim : dims) { + // Either add the dimension as a new dimension or merge it with the + // previous dimension. 
+ if (out->empty() || random_bool(generator())) { + out->push_back(dim); + } else { + out->back() *= dim; + } + } + } + Tensor data = RandomTensor(type, dims_before); + ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Reshape") + .Input(data) + .Input(test::AsTensor( + std::vector(dims_after.begin(), dims_after.end()))) + .Attr("T", type)); + }); +} + +TEST_F(OpTest, Rsqrt) { + Repeatedly([this]() { + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Rsqrt") + .Input(RandomTensor(DT_FLOAT)) + .Attr("T", DT_FLOAT)); + }); +} + +TEST_F(OpTest, RsqrtGrad) { + Repeatedly([this]() { + auto dims = RandomDims(); + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("RsqrtGrad") + .Input(RandomTensor(DT_FLOAT, dims)) + .Input(RandomTensor(DT_FLOAT, dims)) + .Attr("T", DT_FLOAT)); + }); +} + +TEST_F(OpTest, Shape) { + Repeatedly([this]() { + DataType type = Choose(kAllXlaTypes); + ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Shape").Input(RandomTensor(type)).Attr("T", type)); + }); +} + +TEST_F(OpTest, ShapeN) { + Repeatedly([this]() { + DataType type = Choose(kAllXlaTypes); + int n = std::uniform_int_distribution(1, 5)(generator()); + OpTestBuilder builder("ShapeN"); + builder.Attr("T", type); + builder.Attr("N", n); + for (int i = 0; i < n; ++i) { + builder.Input(RandomTensor(type)); + } + ExpectTfAndXlaOutputsAreClose(builder); + }); +} + +TEST_F(OpTest, Sigmoid) { + Repeatedly([this]() { + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Sigmoid") + .Input(RandomTensor(DT_FLOAT)) + .Attr("T", DT_FLOAT)); + }); +} + +TEST_F(OpTest, SigmoidGrad) { + Repeatedly([this]() { + auto dims = RandomDims(); + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SigmoidGrad") + .Input(RandomTensor(DT_FLOAT, dims)) + .Input(RandomTensor(DT_FLOAT, dims)) + .Attr("T", DT_FLOAT)); + }); +} + +TEST_F(OpTest, Sign) { + Repeatedly([this]() { + DataType type = Choose({DT_INT32, DT_FLOAT}); + ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Sign").Input(RandomTensor(type)).Attr("T", type)); + }); +} + +TEST_F(OpTest, Size) { + Repeatedly([this]() { + DataType type = Choose({DT_INT32, DT_FLOAT}); + ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Size").Input(RandomTensor(type)).Attr("T", type)); + }); +} + +TEST_F(OpTest, Slice) { + Repeatedly([this]() { + DataType type = Choose(kAllXlaTypes); + Tensor data = RandomTensor(type); + + std::vector begin(data.dims()), size(data.dims()); + for (int i = 0; i < data.dims(); ++i) { + begin[i] = std::uniform_int_distribution( + 0, data.dim_size(i))(generator()); + size[i] = std::uniform_int_distribution( + -1, data.dim_size(i) - begin[i])(generator()); + } + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Slice") + .Input(data) + .Input(test::AsTensor(begin)) + .Input(test::AsTensor(size)) + .Attr("T", type) + .Attr("Index", DT_INT32)); + }); +} + +TEST_F(OpTest, Softmax) { + Repeatedly([this]() { + ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Softmax") + .Input(RandomTensor(DT_FLOAT, RandomDims(2, 2))) + .Attr("T", DT_FLOAT)); + }); +} + +TEST_F(OpTest, Split) { + Repeatedly([this]() { + DataType type = Choose(kAllXlaTypes); + std::vector dims = RandomDims(1); + std::uniform_int_distribution ud; + int32 dim = std::uniform_int_distribution( + 0, static_cast(dims.size()) - 1)(generator()); + int n = std::uniform_int_distribution(1, 5)(generator()); + // Ensure 'dim' is evenly divisible by 'n'. 
+ dims[dim] /= n; + dims[dim] *= n; + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Split") + .Input(test::AsScalar(dim)) + .Input(RandomTensor(type, dims)) + .Attr("T", type) + .Attr("num_split", n)); + }); +} + +TEST_F(OpTest, Softplus) { + Repeatedly([this]() { + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Softplus") + .Input(RandomTensor(DT_FLOAT)) + .Attr("T", DT_FLOAT)); + }); +} + +TEST_F(OpTest, SoftplusGrad) { + Repeatedly([this]() { + std::vector dims = RandomDims(); + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SoftplusGrad") + .Input(RandomTensor(DT_FLOAT, dims)) + .Input(RandomTensor(DT_FLOAT, dims)) + .Attr("T", DT_FLOAT)); + }); +} + +TEST_F(OpTest, SparseMatMul) { + Repeatedly([this]() { + int64 x = RandomDim(); + int64 y = RandomDim(); + int64 z = RandomDim(); + + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SparseMatMul") + .Input(RandomTensor(DT_FLOAT, {x, y})) + .Input(RandomTensor(DT_FLOAT, {y, z})) + .Attr("Ta", DT_FLOAT) + .Attr("Tb", DT_FLOAT)); + + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SparseMatMul") + .Input(RandomTensor(DT_FLOAT, {y, x})) + .Input(RandomTensor(DT_FLOAT, {y, z})) + .Attr("Ta", DT_FLOAT) + .Attr("Tb", DT_FLOAT) + .Attr("transpose_a", true)); + + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SparseMatMul") + .Input(RandomTensor(DT_FLOAT, {x, y})) + .Input(RandomTensor(DT_FLOAT, {z, y})) + .Attr("Ta", DT_FLOAT) + .Attr("Tb", DT_FLOAT) + .Attr("transpose_b", true)); + + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SparseMatMul") + .Input(RandomTensor(DT_FLOAT, {y, x})) + .Input(RandomTensor(DT_FLOAT, {z, y})) + .Attr("Ta", DT_FLOAT) + .Attr("Tb", DT_FLOAT) + .Attr("transpose_a", true) + .Attr("transpose_b", true)); + }); +} + +TEST_F(OpTest, Sqrt) { + Repeatedly([this]() { + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Sqrt") + .Input(RandomTensor(DT_FLOAT)) + .Attr("T", DT_FLOAT)); + }); +} + +TEST_F(OpTest, SquaredDifference) { + Repeatedly([this]() { + auto dims = BroadcastableDims(); + ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("SquaredDifference") + .Input(RandomTensor(DT_FLOAT, dims.first)) + .Input(RandomTensor(DT_FLOAT, dims.second)) + .Attr("T", DT_FLOAT)); + }); +} + +TEST_F(OpTest, Square) { + Repeatedly([this]() { + DataType type = Choose({DT_INT32, DT_FLOAT}); + ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Square").Input(RandomTensor(type)).Attr("T", type)); + }); +} + +TEST_F(OpTest, Squeeze) { + Repeatedly([this]() { + DataType type = Choose(kAllXlaTypes); + Tensor t = RandomTensor(type, RandomDims(0, kDefaultMaxRank, 0, 5)); + std::bernoulli_distribution random_bool; + std::vector squeeze_dims; + for (int i = 0; i < t.dims(); ++i) { + if (t.dim_size(i) == 1 && random_bool(generator())) { + squeeze_dims.push_back(i); + } + } + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Squeeze") + .Input(t) + .Attr("squeeze_dims", squeeze_dims) + .Attr("T", type)); + }); +} + +TEST_F(OpTest, Sub) { + Repeatedly([this]() { + DataType type = Choose({DT_INT32, DT_FLOAT}); + auto dims = BroadcastableDims(); + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Sub") + .Input(RandomTensor(type, dims.first)) + .Input(RandomTensor(type, dims.second)) + .Attr("T", type)); + }); +} + +TEST_F(OpTest, Sum) { + Repeatedly([this]() { + DataType type = Choose({DT_INT32, DT_FLOAT}); + Tensor data = RandomTensor(type); + Tensor indices = RandomReductionIndices(data.dims()); + bool keep_dims = Choose({false, true}); + ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Sum").Input(data).Input(indices).Attr("T", type).Attr( + "keep_dims", keep_dims)); + }); +} + 
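+// NOTE: illustrative sketch only, not part of the original change. A new
+// randomized test for a hypothetical unary op "Foo" would follow the same
+// pattern used throughout this file: choose a data type, feed random inputs,
+// and check that the TF and XLA outputs agree.
+//
+//   TEST_F(OpTest, Foo) {
+//     Repeatedly([this]() {
+//       DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
+//       ExpectTfAndXlaOutputsAreClose(
+//           OpTestBuilder("Foo").Input(RandomTensor(type)).Attr("T", type));
+//     });
+//   }
+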
+TEST_F(OpTest, StridedSlice) { + Repeatedly([this]() { + DataType type = Choose(kAllXlaTypes); + Tensor data = RandomTensor(type); + + std::vector begin(data.dims()), end(data.dims()); + std::vector strides(data.dims()); + for (int i = 0; i < data.dims(); ++i) { + begin[i] = std::uniform_int_distribution( + -2 * data.dim_size(i), 2 * data.dim_size(i))(generator()); + end[i] = std::uniform_int_distribution( + -2 * data.dim_size(i), 2 * data.dim_size(i))(generator()); + // TODO(b/31360685): support strides other than 1 or -1 + strides[i] = std::bernoulli_distribution()(generator()) ? 1 : -1; + } + int64 max_bitmask = (1LL << data.dims()) - 1; + std::uniform_int_distribution bitmask_distribution(0, max_bitmask); + int64 begin_mask = bitmask_distribution(generator()); + int64 end_mask = bitmask_distribution(generator()); + + // Create a ellipsis bitmask with at most one 1 bit set. + int64 ellipsis_mask = 0; + if (data.dims() > 0 && std::bernoulli_distribution()(generator())) { + int ellipsis_pos = + std::uniform_int_distribution(0, data.dims() - 1)(generator()); + ellipsis_mask = 1LL << ellipsis_pos; + } + + int64 new_axis_mask = bitmask_distribution(generator()); + int64 shrink_axis_mask = bitmask_distribution(generator()); + ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("StridedSlice") + .Input(data) + .Input(test::AsTensor(begin)) + .Input(test::AsTensor(end)) + .Input(test::AsTensor(strides)) + .Attr("T", type) + .Attr("Index", DT_INT32) + .Attr("begin_mask", begin_mask) + .Attr("end_mask", end_mask) + .Attr("ellipsis_mask", ellipsis_mask) + .Attr("new_axis_mask", new_axis_mask) + .Attr("shrink_axis_mask", shrink_axis_mask)); + }); +} + +TEST_F(OpTest, StridedSliceGrad) { + Repeatedly([this]() { + DataType type = Choose(kAllXlaTypes); + + // Dimensions of the forward input. + std::vector dims = RandomDims(); + + std::vector begin(dims.size()), end(dims.size()); + std::vector strides(dims.size()); + for (int i = 0; i < dims.size(); ++i) { + begin[i] = std::uniform_int_distribution(-2 * dims[i], + 2 * dims[i])(generator()); + end[i] = std::uniform_int_distribution(-2 * dims[i], + 2 * dims[i])(generator()); + strides[i] = std::uniform_int_distribution( + -2 * dims[i], 2 * dims[i])(generator()); + } + int64 max_bitmask = (1LL << dims.size()) - 1; + std::uniform_int_distribution bitmask_distribution(0, max_bitmask); + int64 begin_mask = bitmask_distribution(generator()); + int64 end_mask = bitmask_distribution(generator()); + + // Create a ellipsis bitmask with at most one 1 bit set. + int64 ellipsis_mask = 0; + if (!dims.empty() && std::bernoulli_distribution()(generator())) { + int ellipsis_pos = + std::uniform_int_distribution(0, dims.size() - 1)(generator()); + ellipsis_mask = 1LL << ellipsis_pos; + } + + int64 new_axis_mask = bitmask_distribution(generator()); + int64 shrink_axis_mask = bitmask_distribution(generator()); + + // TODO(phawkins): use shape inference for the forward op to compute the + // gradient shape for the backward op. At present, there is a low + // probability of the golden op succeeding. 
+    ExpectTfAndXlaOutputsAreClose(
+        OpTestBuilder("StridedSliceGrad")
+            .Input(test::AsTensor<int64>(dims))
+            .Input(test::AsTensor<int64>(begin))
+            .Input(test::AsTensor<int64>(end))
+            .Input(test::AsTensor<int64>(strides))
+            .Input(RandomTensor(type, RandomDims(1)))
+            .Attr("T", type)
+            .Attr("Index", DT_INT64)
+            .Attr("begin_mask", begin_mask)
+            .Attr("end_mask", end_mask)
+            .Attr("ellipsis_mask", ellipsis_mask)
+            .Attr("new_axis_mask", new_axis_mask)
+            .Attr("shrink_axis_mask", shrink_axis_mask));
+  });
+}
+
+TEST_F(OpTest, Tanh) {
+  Repeatedly([this]() {
+    ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Tanh")
+                                      .Input(RandomTensor(DT_FLOAT))
+                                      .Attr("T", DT_FLOAT));
+  });
+}
+
+TEST_F(OpTest, TanhGrad) {
+  Repeatedly([this]() {
+    auto dims = RandomDims();
+    ExpectTfAndXlaOutputsAreClose(OpTestBuilder("TanhGrad")
+                                      .Input(RandomTensor(DT_FLOAT, dims))
+                                      .Input(RandomTensor(DT_FLOAT, dims))
+                                      .Attr("T", DT_FLOAT));
+  });
+}
+
+TEST_F(OpTest, Tile) {
+  Repeatedly([this]() {
+    DataType type = Choose<DataType>(kAllXlaTypes);
+    Tensor t = RandomTensor(type, RandomDims(1));
+    std::vector<int32> multiples(t.dims());
+    for (int i = 0; i < t.dims(); ++i) {
+      multiples[i] = std::uniform_int_distribution<int>(1, 3)(generator());
+    }
+    ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Tile")
+                                      .Input(t)
+                                      .Input(test::AsTensor<int32>(multiples))
+                                      .Attr("T", type));
+  });
+}
+
+TEST_F(OpTest, Transpose) {
+  Repeatedly([this]() {
+    DataType type = Choose<DataType>(kAllXlaTypes);
+    Tensor data = RandomTensor(type);
+    std::vector<int32> perm(data.dims());
+    std::iota(perm.begin(), perm.end(), 0);
+    std::shuffle(perm.begin(), perm.end(), generator());
+    ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Transpose")
+                                      .Input(data)
+                                      .Input(test::AsTensor<int32>(perm))
+                                      .Attr("T", type));
+  });
+}
+
+TEST_F(OpTest, TruncateDiv) {
+  Repeatedly([this]() {
+    DataType type = DT_INT32;
+    auto dims = BroadcastableDims();
+    ExpectTfAndXlaOutputsAreClose(OpTestBuilder("TruncateDiv")
+                                      .Input(RandomTensor(type, dims.first))
+                                      .Input(RandomTensor(type, dims.second))
+                                      .Attr("T", type));
+  });
+}
+
+TEST_F(OpTest, TruncateMod) {
+  Repeatedly([this]() {
+    DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
+    auto dims = BroadcastableDims();
+    ExpectTfAndXlaOutputsAreClose(OpTestBuilder("TruncateMod")
+                                      .Input(RandomTensor(type, dims.first))
+                                      .Input(RandomTensor(type, dims.second))
+                                      .Attr("T", type));
+  });
+}
+
+TEST_F(OpTest, ZerosLike) {
+  Repeatedly([this]() {
+    DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
+    ExpectTfAndXlaOutputsAreClose(
+        OpTestBuilder("ZerosLike").Input(RandomTensor(type)).Attr("T", type));
+  });
+}
+
+}  // anonymous namespace
+}  // namespace tensorflow
+
+int main(int argc, char** argv) {
+  tensorflow::tf_xla_test_device_ptr = new tensorflow::string("GPU");
+  std::vector<tensorflow::Flag> flag_list = {
+      tensorflow::Flag(
+          "tf_xla_random_seed", &tensorflow::tf_xla_random_seed,
+          "Random seed to use for XLA tests. <= 0 means choose a seed "
+          "nondeterministically."),
+      // TODO(phawkins): it might make more sense to run each test up to a
+      // configurable time bound.
+ tensorflow::Flag("tf_xla_test_repetitions", + &tensorflow::tf_xla_test_repetitions, + "Number of repetitions for each test."), + tensorflow::Flag("tf_xla_test_device", tensorflow::tf_xla_test_device_ptr, + "Tensorflow device type to use for test"), + tensorflow::Flag("tf_xla_test_use_jit", &tensorflow::tf_xla_test_use_jit, + "Use JIT compilation for the operator under test"), + }; + tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + // XLA devices register kernels at construction time; create and destroy all + // known devices to make sure the kernels are registered. + std::vector devices; + TF_CHECK_OK(tensorflow::DeviceFactory::AddDevices( + tensorflow::SessionOptions(), "", &devices)); + for (tensorflow::Device* device : devices) { + delete device; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/tests/reduce_ops_test.py b/tensorflow/compiler/tests/reduce_ops_test.py new file mode 100644 index 0000000000..efda2cc207 --- /dev/null +++ b/tensorflow/compiler/tests/reduce_ops_test.py @@ -0,0 +1,125 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for reduction operators.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors_impl +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import googletest + + +class ReduceOpsTest(XLATestCase): + + def _testReduction(self, tf_reduce_fn, np_reduce_fn, dtype, test_inputs, + rtol=1e-4, atol=1e-4): + """Tests that the output of 'tf_reduce_fn' matches numpy's output.""" + + for test_input in test_inputs: + with self.test_session() as sess: + with self.test_scope(): + a = array_ops.placeholder(dtype) + index = array_ops.placeholder(dtypes.int32) + out = tf_reduce_fn(a, index) + result = sess.run(out, {a: test_input, index: [0]}) + self.assertAllClose(result, np_reduce_fn(test_input, axis=0), + rtol=rtol, atol=atol) + + result = sess.run(out, {a: test_input, index: [1]}) + self.assertAllClose(result, np_reduce_fn(test_input, axis=1), + rtol=rtol, atol=atol) + + result = sess.run(out, {a: test_input, index: [-1]}) + self.assertAllClose(result, np_reduce_fn(test_input, axis=1), + rtol=rtol, atol=atol) + + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, 'Invalid reduction dim'): + sess.run(out, {a: test_input, index: [-33]}) + + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, 'Invalid reduction dim'): + sess.run(out, {a: test_input, index: [2]}) + + FLOAT_DATA = [ + np.zeros(shape=(2, 0)), + np.zeros(shape=(0, 30)), + np.arange(1, 7).reshape(2, 3), + np.arange(-10, -4).reshape(2, 3), + np.arange(-4, 2).reshape(2, 3), + ] + NONEMPTY_FLOAT_DATA = [ + np.arange(1, 7).reshape(2, 3), + np.arange(-10, -4).reshape(2, 3), + np.arange(-4, 2).reshape(2, 3), + ] + BOOL_DATA = [ + np.array([], dtype=np.bool).reshape(2, 0), + np.array([], dtype=np.bool).reshape(0, 3), + np.array([[False, True, False], [True, True, False]]), + ] + + def testReduceSum(self): + self._testReduction(math_ops.reduce_sum, np.sum, np.float32, + self.FLOAT_DATA) + + def testReduceProd(self): + self._testReduction(math_ops.reduce_prod, np.prod, np.float32, + self.FLOAT_DATA) + + def testReduceMin(self): + + def reference_min(inp, axis): + """Wrapper around np.amin that returns +infinity for an empty input.""" + if inp.shape[axis] == 0: + return np.full(inp.shape[0:axis] + inp.shape[axis + 1:], float('inf')) + return np.amin(inp, axis) + + self._testReduction(math_ops.reduce_min, reference_min, np.float32, + self.FLOAT_DATA) + + def testReduceMax(self): + + def reference_max(inp, axis): + """Wrapper around np.amax that returns -infinity for an empty input.""" + if inp.shape[axis] == 0: + return np.full(inp.shape[0:axis] + inp.shape[axis + 1:], float('-inf')) + return np.amax(inp, axis) + + self._testReduction(math_ops.reduce_max, reference_max, np.float32, + self.FLOAT_DATA) + + def testReduceMean(self): + # TODO(phawkins): mean on XLA currently returns 0 instead of NaN when + # reducing across zero inputs. 
+ self._testReduction(math_ops.reduce_mean, np.mean, np.float32, + self.NONEMPTY_FLOAT_DATA) + + def testReduceAll(self): + self._testReduction(math_ops.reduce_all, np.all, np.bool, self.BOOL_DATA) + + def testReduceAny(self): + self._testReduction(math_ops.reduce_any, np.any, np.bool, self.BOOL_DATA) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/compiler/tests/ternary_ops_test.py b/tensorflow/compiler/tests/ternary_ops_test.py new file mode 100644 index 0000000000..22024f4511 --- /dev/null +++ b/tensorflow/compiler/tests/ternary_ops_test.py @@ -0,0 +1,110 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test cases for ternary operators.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import googletest + + +class TernaryOpsTest(XLATestCase): + + def _testTernary(self, op, a, b, c, expected): + with self.test_session() as session: + with self.test_scope(): + pa = array_ops.placeholder(dtypes.as_dtype(a.dtype), a.shape, name="a") + pb = array_ops.placeholder(dtypes.as_dtype(b.dtype), b.shape, name="b") + pc = array_ops.placeholder(dtypes.as_dtype(c.dtype), c.shape, name="c") + output = op(pa, pb, pc) + result = session.run(output, {pa: a, pb: b, pc: c}) + self.assertAllClose(result, expected, rtol=1e-3) + + def testLinspace(self): + self._testTernary( + math_ops.linspace, + np.float32(1), + np.float32(2), + np.int32(1), + expected=np.array([1], dtype=np.float32)) + self._testTernary( + math_ops.linspace, + np.float32(1), + np.float32(4), + np.int32(3), + expected=np.array([1, 2.5, 4], dtype=np.float32)) + + def testRange(self): + self._testTernary( + math_ops.range, + np.int32(1), + np.int32(2), + np.int32(1), + expected=np.array([1], dtype=np.int32)) + self._testTernary( + math_ops.range, + np.int32(1), + np.int32(7), + np.int32(2), + expected=np.array([1, 3, 5], dtype=np.int32)) + + def testSelect(self): + self._testTernary( + array_ops.where, + np.array(0, dtype=np.bool), + np.array(2, dtype=np.float32), + np.array(7, dtype=np.float32), + expected=np.array(7, dtype=np.float32)) + + self._testTernary( + array_ops.where, + np.array([0, 1, 1, 0], dtype=np.bool), + np.array([1, 2, 3, 4], dtype=np.float32), + np.array([5, 6, 7, 8], dtype=np.float32), + expected=np.array([5, 2, 3, 8], dtype=np.float32)) + + self._testTernary( + array_ops.where, + np.array([0, 1, 0], dtype=np.bool), + np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32), + np.array([[7, 8], [9, 10], [11, 12]], dtype=np.float32), + expected=np.array([[7, 8], [3, 4], [11, 12]], dtype=np.float32)) + + def testSlice(self): + for dtype in 
self.numeric_types: + self._testTernary( + array_ops.slice, + np.array([[], [], []], dtype=dtype), + np.array([1, 0], dtype=np.int32), + np.array([2, 0], dtype=np.int32), + expected=np.array([[], []], dtype=dtype)) + + self._testTernary( + array_ops.slice, + np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=dtype), + np.array([0, 1], dtype=np.int32), + np.array([2, 1], dtype=np.int32), + expected=np.array([[2], [5]], dtype=dtype)) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py new file mode 100644 index 0000000000..33e0424e60 --- /dev/null +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -0,0 +1,346 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for XLA JIT compiler.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_nn_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.platform import googletest + + +class UnaryOpsTest(XLATestCase): + """Test cases for unary operators.""" + + def _testUnary(self, op, inp, expected, equality_test=None): + with self.test_session() as session: + with self.test_scope(): + pinp = array_ops.placeholder( + dtypes.as_dtype(inp.dtype), inp.shape, name="a") + output = op(pinp) + result = session.run(output, {pinp: inp}) + if equality_test is None: + equality_test = self.assertAllClose + equality_test(result, expected, rtol=1e-3) + + def ListsAreClose(self, result, expected, rtol): + """Tests closeness of two lists of floats.""" + self.assertEqual(len(result), len(expected)) + for i in range(len(result)): + self.assertAllClose(result[i], expected[i], rtol) + + def testAllTypeOps(self): + for dtype in self.numeric_types: + self._testUnary( + array_ops.diag, + np.array([1, 2, 3, 4], dtype=dtype), + np.array([[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]], + dtype=dtype)) + self._testUnary( + array_ops.diag_part, + np.arange(36).reshape([2, 3, 2, 3]).astype(dtype), + np.array([[0, 7, 14], [21, 28, 35]], dtype=dtype)) + + self._testUnary( + array_ops.identity, + np.array([[-1, 1]], dtype=dtype), + expected=np.array([[-1, 1]], dtype=dtype)) + + self._testUnary( + array_ops.matrix_diag, + np.array([[1, 2], [3, 4]], dtype=dtype), + np.array([[[1, 0], [0, 2]], [[3, 0], [0, 4]]], dtype=dtype)) + self._testUnary( + array_ops.matrix_diag_part, + np.arange(3 * 2 * 4).reshape([3, 2, 4]).astype(dtype), + np.array([[0, 5], [8, 13], [16, 21]], dtype=dtype)) + + self._testUnary( + array_ops.prevent_gradient, + np.array([[-1, 1]], dtype=dtype), + expected=np.array([[-1, 1]], 
dtype=dtype)) + + self._testUnary( + array_ops.squeeze, + np.array([[[[[]]]]], dtype=dtype), + expected=np.array([], dtype=dtype)) + self._testUnary( + array_ops.squeeze, + np.array([[[1], [2]]], dtype=dtype), + expected=np.array([1, 2], dtype=dtype)) + self._testUnary( + array_ops.squeeze, + np.array([[[1]], [[2]]], dtype=dtype), + expected=np.array([1, 2], dtype=dtype)) + self._testUnary( + array_ops.squeeze, + np.array([[[1, 2], [3, 4]]], dtype=dtype), + expected=np.array([[1, 2], [3, 4]], dtype=dtype)) + + self._testUnary( + array_ops.stop_gradient, + np.array([[-1, 1]], dtype=dtype), + expected=np.array([[-1, 1]], dtype=dtype)) + + def testFloatOps(self): + for dtype in self.float_types: + self._testUnary( + math_ops.ceil, + np.array([[-1.7, 1.2]], dtype=dtype), + expected=np.array([[-1, 2]], dtype=dtype)) + + self._testUnary( + math_ops.exp, + np.array([[-1, 1]], dtype=dtype), + expected=np.array([[0.36787945, 2.7182817]], dtype=dtype)) + + self._testUnary( + math_ops.floor, + np.array([[-1.7, 1.2]], dtype=dtype), + expected=np.array([[-2, 1]], dtype=dtype)) + + # Tests for tf.nn ops. + self._testUnary( + nn_ops.l2_loss, np.array([[[]]], dtype=dtype), expected=dtype(0)) + + # TODO(b/31644876): enable this test case when fixed. + # self._testUnary(tf.nn.l2_loss, dtype(4), dtype(10)) + + self._testUnary( + nn_ops.l2_loss, np.array([[-2, 4]], dtype=dtype), expected=dtype(10)) + + self._testUnary( + math_ops.reciprocal, + np.array([[1, 2]], dtype=dtype), + expected=np.array([[1, 0.5]], dtype=dtype)) + + self._testUnary( + math_ops.log, + np.array([[1, 2]], dtype=dtype), + expected=np.array([[0, 0.69314718]], dtype=dtype)) + + self._testUnary( + math_ops.rsqrt, + np.array([[4, 16]], dtype=dtype), + expected=np.array([[0.5, 0.25]], dtype=dtype)) + + self._testUnary( + math_ops.sigmoid, + np.array( + [[1, 1, 1, 1], + [1, 2, 3, 4]], + dtype=dtype), + expected=np.array( + [[0.7310586, 0.7310586, 0.7310586, 0.7310586], + [0.7310586, 0.880797, 0.95257413, 0.98201376]], + dtype=dtype)) + + self._testUnary( + math_ops.sqrt, + np.array([[4, 9]], dtype=dtype), + expected=np.array([[2, 3]], dtype=dtype)) + + self._testUnary( + math_ops.tanh, + np.array( + [[1, 1, 1, 1], + [1, 2, 3, 4]], + dtype=dtype), + expected=np.array( + [[0.76159418, 0.76159418, 0.76159418, 0.76159418], + [0.76159418, 0.96402758, 0.99505478, 0.99932933]], + dtype=dtype)) + + self._testUnary( + nn_ops.log_softmax, + np.array( + [[1, 1, 1, 1], + [1, 2, 3, 4]], + dtype=dtype), + expected=np.array( + [[-1.3862944, -1.3862944, -1.3862944, -1.3862944], + [-3.4401896, -2.4401896, -1.4401897, -0.44018969]], + dtype=dtype)) + + self._testUnary( + nn_ops.relu, + np.array([[-1, 1]], dtype=dtype), + expected=np.array([[0, 1]], dtype=dtype)) + + self._testUnary( + nn_ops.relu6, + np.array([[-0.05, 6.05, 5]], dtype=dtype), + expected=np.array([[0, 6, 5]], dtype=dtype)) + + self._testUnary( + nn_ops.softmax, + np.array( + [[1, 1, 1, 1], + [1, 2, 3, 4]], + dtype=dtype), + expected=np.array( + [[0.25, 0.25, 0.25, 0.25], + [0.032058604, 0.087144323, 0.23688284, 0.64391428]], + dtype=dtype)) + + self._testUnary( + nn_ops.softplus, + np.array([[-2, 0, 8]], dtype=dtype), + expected=np.array([[0.126928, 0.6931472, 8.0003354]], dtype=dtype)) + + def testNumericOps(self): + for dtype in self.numeric_types: + self._testUnary( + math_ops.abs, + np.array([[2, -1]], dtype=dtype), + expected=np.array([[2, 1]], dtype=dtype)) + + self._testUnary( + math_ops.neg, + np.array([[-1, 1]], dtype=dtype), + expected=np.array([[1, -1]], dtype=dtype)) + + 
self._testUnary( + math_ops.square, + np.array([[-2, 3]], dtype=dtype), + expected=np.array([[4, 9]], dtype=dtype)) + + self._testUnary( + array_ops.zeros_like, + np.array([[4, 3], [2, 1]], dtype=dtype), + expected=np.array([[0, 0], [0, 0]], dtype=dtype)) + + def testLogicalOps(self): + self._testUnary( + math_ops.logical_not, + np.array([[True, False], [False, True]], dtype=np.bool), + expected=np.array([[False, True], [True, False]], dtype=np.bool)) + + def testBiasAddGrad(self): + self._testUnary( + gen_nn_ops.bias_add_grad, + np.array([[1., 2.], [3., 4.]], dtype=np.float32), + expected=np.array([4., 6.], dtype=np.float32)) + + self._testUnary(lambda x: gen_nn_ops.bias_add_grad(x, data_format="NCHW"), + np.array([[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]]], + dtype=np.float32), + expected=np.array([10., 26.], dtype=np.float32)) + + def testCast(self): + shapes = [[], [4], [2, 3], [2, 0, 4]] + types = [dtypes.bool, dtypes.int32, dtypes.float32] + for shape in shapes: + for src_type in types: + for dst_type in types: + src = np.arange(np.prod(shape)).astype(src_type.as_numpy_dtype) + src = src.reshape(shape) + + dst = src.astype(dst_type.as_numpy_dtype) + self._testUnary( + lambda x, dst_type=dst_type: math_ops.cast(x, dst_type), + src, + expected=dst) + + def testInvertPermutation(self): + self._testUnary( + array_ops.invert_permutation, + np.array([1, 2, 0], np.int32), + expected=np.array([2, 0, 1], dtype=np.int32)) + + def testRank(self): + rank_op = lambda x: array_ops.rank_internal(x, optimize=False) + for dtype in self.numeric_types: + self._testUnary(rank_op, dtype(7), expected=np.int32(0)) + self._testUnary( + rank_op, np.array( + [[], []], dtype=dtype), expected=np.int32(2)) + self._testUnary( + rank_op, np.array( + [-1, 1], dtype=dtype), expected=np.int32(1)) + self._testUnary( + rank_op, np.array( + [[-1, 1]], dtype=dtype), expected=np.int32(2)) + self._testUnary( + rank_op, + np.array([[-1], [1], [4]], dtype=dtype), + expected=np.int32(2)) + + def testShape(self): + shape_op = lambda x: array_ops.shape_internal(x, optimize=False) + for dtype in self.numeric_types: + self._testUnary(shape_op, dtype(7), expected=np.array([], dtype=np.int32)) + self._testUnary( + shape_op, + np.array([[], []], dtype=dtype), + expected=np.array([2, 0], dtype=np.int32)) + self._testUnary( + shape_op, + np.array([-1, 1], dtype=dtype), + expected=np.array([2], dtype=np.int32)) + self._testUnary( + shape_op, + np.array([[-1, 1]], dtype=dtype), + expected=np.array([1, 2], dtype=np.int32)) + self._testUnary( + shape_op, + np.array([[-1], [1], [4]], dtype=dtype), + expected=np.array([3, 1], dtype=np.int32)) + + def testSize(self): + size_op = lambda x: array_ops.size_internal(x, optimize=False) + for dtype in self.numeric_types: + self._testUnary(size_op, dtype(7), expected=np.int32(1)) + self._testUnary( + size_op, np.array([[], []], dtype=dtype), expected=np.int32(0)) + self._testUnary( + size_op, np.array([-1, 1], dtype=dtype), expected=np.int32(2)) + self._testUnary( + size_op, np.array([[-1, 1]], dtype=dtype), expected=np.int32(2)) + self._testUnary( + size_op, + np.array([[-1], [1], [4]], dtype=dtype), + expected=np.int32(3)) + + def testUnpack(self): + self._testUnary( + array_ops.unpack, + np.array([[1., 2.], [3., 4.], [5., 6.]], dtype=np.float32), + expected=[ + np.array([1., 2.], dtype=np.float32), + np.array([3., 4.], dtype=np.float32), + np.array([5., 6.], dtype=np.float32), + ], + equality_test=self.ListsAreClose) + + self._testUnary(lambda x: array_ops.unstack(x, axis=1), + np.array([[1., 
2.], [3., 4.], [5., 6.]], dtype=np.float32), + expected=[ + np.array([1., 3., 5.], dtype=np.float32), + np.array([2., 4., 6.], dtype=np.float32), + ], + equality_test=self.ListsAreClose) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/compiler/tests/xla_device_test.py b/tensorflow/compiler/tests/xla_device_test.py new file mode 100644 index 0000000000..1388a892ba --- /dev/null +++ b/tensorflow/compiler/tests/xla_device_test.py @@ -0,0 +1,81 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test cases for XLA devices.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.client import session as session_lib +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class XlaDeviceTest(test.TestCase): + + def testCopies(self): + """Tests that copies between GPU and XLA devices work.""" + if not test.is_gpu_available(): + return + + with session_lib.Session() as sess: + x = array_ops.placeholder(dtypes.float32, [2]) + with ops.device("GPU"): + y = x * 2 + with ops.device("device:XLA_CPU:0"): + z = y * y + with ops.device("GPU"): + w = y + z + result = sess.run(w, {x: [1.5, 0.5]}) + self.assertAllClose(result, [12., 2.], rtol=1e-3) + + def testLoops(self): + """Tests that loops work on XLA devices.""" + + with session_lib.Session() as session: + x = array_ops.placeholder(dtypes.float32) + with ops.device("device:XLA_CPU:0"): + c = lambda i, _: math_ops.less(i, 5) + b = lambda i, x: (i + 1, x * 2.0 + 1.0) + _, y = control_flow_ops.while_loop(c, b, (constant_op.constant(0), x)) + + result = session.run(y, {x: np.float32(2)}) + self.assertAllClose(result, np.float32(95), rtol=1e-3) + + def testCond(self): + """Tests that tf.cond works on XLA devices.""" + + with session_lib.Session() as session: + x = array_ops.placeholder(dtypes.float32) + y = array_ops.placeholder(dtypes.float32) + c = array_ops.placeholder(dtypes.bool) + with ops.device("device:XLA_CPU:0"): + z = x + 1.0 + w = control_flow_ops.cond(c, lambda: z, lambda: y) + t = math_ops.add(z, w) + + result = session.run(t, {x: np.float32(2), y: np.float32(4), c: True}) + self.assertAllClose(result, np.float32(6), rtol=1e-3) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/xla_test.py b/tensorflow/compiler/tests/xla_test.py new file mode 100644 index 0000000000..b72e7c9713 --- /dev/null +++ b/tensorflow/compiler/tests/xla_test.py @@ -0,0 +1,148 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Definition of XLA test case.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import contextlib +import re + +from tensorflow.contrib.compiler import jit +from tensorflow.core.framework import types_pb2 +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.client import session +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import flags +from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging as logging + +FLAGS = flags.FLAGS + +flags.DEFINE_string('test_device', None, + 'Tensorflow device on which to place operators under test') +flags.DEFINE_string('types', None, 'Types to test. Comma-separated list.') +flags.DEFINE_string('disabled_manifest', None, + 'Path to a file with a list of tests that should not run.') + + +class XLATestCase(test.TestCase): + """XLA test cases are parameterized test cases.""" + + def __init__(self, method_name='runTest'): + super(XLATestCase, self).__init__(method_name) + self.device = FLAGS.test_device + self.has_custom_call = (self.device == 'XLA_CPU') + self.all_tf_types = [ + dtypes.DType(types_pb2.DataType.Value(name)) + for name in FLAGS.types.split(',') + ] + self.all_types = [dtype.as_numpy_dtype for dtype in self.all_tf_types] + self.int_types = [ + dtype.as_numpy_dtype for dtype in self.all_tf_types if dtype.is_integer + ] + self.float_types = [ + dtype.as_numpy_dtype for dtype in self.all_tf_types if dtype.is_floating + ] + self.numeric_types = self.int_types + self.float_types + + # Parse the manifest file, if any, into a regex identifying tests to + # disable + self.disabled_regex = None + if FLAGS.disabled_manifest is not None: + comments_re = re.compile('#.*$') + manifest_file = open(FLAGS.disabled_manifest, 'r') + lines = manifest_file.read().splitlines() + lines = [comments_re.sub('', l).strip() for l in lines] + self.disabled_regex = re.compile('|'.join(lines)) + manifest_file.close() + + def setUp(self): + name = '{}.{}'.format(type(self).__name__, self._testMethodName) + if self.disabled_regex is not None and self.disabled_regex.match(name): + logging.info('Disabled test case: %s', name) + self.skipTest('{} is disabled by manifest.'.format(name)) + return + logging.info('Start test case: %s', name) + + def tearDown(self): + logging.info('End test case: %s', self._testMethodName) + + @contextlib.contextmanager + def test_session(self): + """Custom implementation of test_session() for XLA tests. + + We override the standard Tensorflow test_session() since it is too + specific to CPU and GPU tests. In particular, we want to disable soft + placement and explicitly assign ops to devices under test. 
+ + Yields: + A session to use when running a test case. + """ + graph = ops.Graph() + with session.Session(graph=graph) as sess, graph.as_default(): + yield sess + + @contextlib.contextmanager + def test_scope(self): + """Test scope that runs tests on a Tensorflow/XLA device. + + Uses a compilation_scope() to mark operators to compile. + + Yields: + A scope to apply to the operators under test. + """ + with ops.device('device:{}:0'.format(self.device)): + yield + + +def Benchmark(tf_bench, builder_fn, use_xla_jit, device): + """Build a graph and run benchmarks against it, with or without XLA. + + Args: + tf_bench: An instance of tf.test.Benchmark, used to run the benchmark. + builder_fn: A function that builds a graph when invoked, and returns + (name, fetches), where name is the name of the test, and fetches + is a list of tensors to fetch as output. + use_xla_jit: If true compile with the XLA JIT, otherwise use regular TF. + device: The tensorflow device to run on, e.g. "cpu", "gpu". + """ + + with ops.Graph().as_default(): + name = None + targets = [] + with ops.device(device): + fetches = [] + jit_scope = jit.experimental_jit_scope + with jit_scope(compile_ops=use_xla_jit): + name, fetches = builder_fn() + + # We only want to benchmark the operations themselves, and not the data + # transfer of the result(s). Non-compiled identity ops ensure XLA + # doesn't know we're dropping the results, otherwise it might compile + # away the entire computation. + for fetch in fetches: + targets.append(array_ops.identity(fetch).op) + + config = config_pb2.ConfigProto(allow_soft_placement=True) + with session.Session(config=config) as sess: + sess.run(variables.global_variables_initializer()) + xla = 'xla_' if use_xla_jit else '' + tf_bench.run_op_benchmark( + sess, targets, name='%s_%s%s' % (name, xla, device)) diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD new file mode 100644 index 0000000000..3de9958cd6 --- /dev/null +++ b/tensorflow/compiler/tf2xla/BUILD @@ -0,0 +1,193 @@ +licenses(["notice"]) # Apache 2.0 + +package_group( + name = "internal", + packages = [ + "//tensorflow/compiler/aot/...", + "//tensorflow/compiler/jit/...", + "//tensorflow/compiler/tests/...", + "//tensorflow/compiler/tf2xla/...", + ], +) + +package_group( + name = "friends", + includes = [":internal"], + packages = ["//tensorflow/..."], +) + +package( + default_visibility = [":internal"], +) + +cc_library( + name = "xla_compiler", + srcs = [ + "op_registrations.cc", + "xla_compilation_device.cc", + "xla_compiler.cc", + "xla_context.cc", + "xla_helpers.cc", + "xla_op_kernel.cc", + ], + hdrs = [ + "xla_compilation_device.h", + "xla_compiler.h", + "xla_context.h", + "xla_helpers.h", + "xla_op_kernel.h", + ], + deps = [ + ":common", + ":dump_graph", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/service:cpu_plugin", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core:tensorflow_opensource", + "//tensorflow/core/kernels:cwise_op", + ], + alwayslink = 1, +) + +cc_library( + 
name = "common", + srcs = [ + "literal_util.cc", + "shape_util.cc", + "str_util.cc", + "type_util.cc", + ], + hdrs = [ + "literal_util.h", + "shape_util.h", + "str_util.h", + "type_util.h", + ], + visibility = [":friends"], + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], +) + +# Internal targets below this point. + +cc_test( + name = "str_util_test", + srcs = [ + "str_util_test.cc", + ], + deps = [ + ":common", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_test( + name = "literal_util_test", + srcs = [ + "literal_util_test.cc", + ], + deps = [ + ":common", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/core:framework", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + +cc_library( + name = "const_analysis", + srcs = ["const_analysis.cc"], + hdrs = ["const_analysis.h"], + deps = [ + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + ], +) + +cc_test( + name = "const_analysis_test", + size = "small", + srcs = ["const_analysis_test.cc"], + deps = [ + ":const_analysis", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:function_ops", + "//tensorflow/cc:ops", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:ops", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + +cc_library( + name = "xla_local_runtime_context", + hdrs = ["xla_local_runtime_context.h"], + visibility = ["//visibility:public"], + deps = ["//tensorflow/core:framework_lite"], +) + +cc_library( + name = "dump_graph", + srcs = [ + "dump_graph.cc", + "dump_graph_flags.cc", + "dump_graph_flags.h", + ], + hdrs = [ + "dump_graph.h", + ], + deps = [ + "//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], +) + +# ----------------------------------------------------------------------------- + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc new file mode 100644 index 0000000000..e072ef7be7 --- /dev/null +++ b/tensorflow/compiler/tf2xla/const_analysis.cc @@ -0,0 +1,139 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/const_analysis.h"
+
+#include <unordered_map>
+#include <unordered_set>
+
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/graph/algorithm.h"
+
+namespace tensorflow {
+
+// Backwards dataflow analysis that finds arguments to a graph that must be
+// compile-time constants.
+Status BackwardsConstAnalysis(const Graph& g,
+                              std::vector<bool>* compile_time_const_args) {
+  // TODO(phawkins): annotate these on the kernel registrations, rather than
+  // using a hard-coded list.
+  // (operator, argument) pairs that must be compile-time constants.
+  const std::unordered_multimap<string, string> compile_time_const_inputs = {
+      {"All", "reduction_indices"},
+      {"Any", "reduction_indices"},
+      {"ArgMax", "dimension"},
+      {"AvgPoolGrad", "orig_input_shape"},
+      {"BroadcastGradientArgs", "s0"},
+      {"BroadcastGradientArgs", "s1"},
+      {"Concat", "concat_dim"},
+      {"ConcatV2", "axis"},
+      {"ConcatOffset", "concat_dim"},
+      {"ConcatOffset", "shape"},
+      {"Conv2DBackpropFilter", "filter_sizes"},
+      {"Conv2DBackpropInput", "input_sizes"},
+      {"DynamicStitch", "indices"},
+      {"ExpandDims", "dim"},
+      {"Fill", "dims"},
+      {"InvertPermutation", "x"},
+      {"LinSpace", "start"},
+      {"LinSpace", "stop"},
+      {"LinSpace", "num"},
+      {"Max", "reduction_indices"},
+      {"Mean", "reduction_indices"},
+      {"Min", "reduction_indices"},
+      {"Pad", "paddings"},
+      {"Prod", "reduction_indices"},
+      {"RandomStandardNormal", "shape"},
+      {"RandomUniform", "shape"},
+      {"RandomUniformInt", "shape"},
+      {"Range", "start"},
+      {"Range", "limit"},
+      {"Range", "delta"},
+      {"Reshape", "shape"},
+      {"Slice", "begin"},
+      {"Slice", "size"},
+      {"Split", "split_dim"},
+      {"SplitV", "split_dim"},
+      {"SplitV", "size_splits"},
+      {"StridedSlice", "begin"},
+      {"StridedSlice", "end"},
+      {"StridedSlice", "strides"},
+      {"StridedSliceGrad", "shape"},
+      {"StridedSliceGrad", "begin"},
+      {"StridedSliceGrad", "end"},
+      {"StridedSliceGrad", "strides"},
+      {"Sum", "reduction_indices"},
+      {"Tile", "multiples"},
+      {"Transpose", "perm"}};
+
+  // Operators that don't look at the data of their inputs, just the shapes.
+  const std::unordered_set<string> metadata_ops = {
+      "Rank", "Shape", "ShapeN", "Size",
+  };
+
+  Status status;
+  std::unordered_set<Node*> must_be_const;
+  auto visit = [&status, &metadata_ops, &compile_time_const_inputs,
+                &must_be_const, compile_time_const_args](Node* node) {
+    if (!status.ok()) return;
+
+    // If this is a metadata-only op, don't propagate the const requirement.
+    if (metadata_ops.find(node->type_string()) != metadata_ops.end()) return;
+
+    // If this node must be const, and it isn't a metadata op, then all of its
+    // parents must be const.
+    if (must_be_const.find(node) != must_be_const.end()) {
+      if (node->type_string() == "_Arg") {
+        int index;
+        status = GetNodeAttr(node->def(), "index", &index);
+        if (!status.ok()) return;
+        compile_time_const_args->at(index) = true;
+        return;
+      }
+      for (Node* pred : node->in_nodes()) {
+        must_be_const.insert(pred);
+      }
+      return;
+    }
+
+    // Mark any compile-time constant operator arguments as const.
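+    // For example, when a Reshape node is visited, the producer feeding its
+    // "shape" input is inserted into must_be_const below. Since the traversal
+    // visits consumers before their producers, that producer propagates the
+    // requirement further back when it is visited later.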
+    auto range = compile_time_const_inputs.equal_range(node->type_string());
+    if (range.first == range.second) return;
+
+    NameRangeMap input_name_ranges;
+    status = NameRangesForNode(node->def(), node->op_def(), &input_name_ranges,
+                               nullptr);
+    if (!status.ok()) return;
+
+    for (auto it = range.first; it != range.second; ++it) {
+      auto name_range = input_name_ranges.find(it->second);
+      if (name_range == input_name_ranges.end()) continue;
+
+      for (Edge const* edge : node->in_edges()) {
+        if (edge->dst_input() >= name_range->second.first &&
+            edge->dst_input() < name_range->second.second) {
+          must_be_const.insert(edge->src());
+        }
+      }
+    }
+  };
+
+  // Post-order traversal visits nodes in reverse topological order for an
+  // acyclic graph.
+  DFS(g, {}, visit);
+  return status;
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/const_analysis.h b/tensorflow/compiler/tf2xla/const_analysis.h
new file mode 100644
index 0000000000..634b97d7e3
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/const_analysis.h
@@ -0,0 +1,33 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_CONST_ANALYSIS_H_
+#define TENSORFLOW_COMPILER_TF2XLA_CONST_ANALYSIS_H_
+
+#include <vector>
+
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+// Backwards dataflow analysis that finds arguments (_Arg nodes) to a graph that
+// must be compile-time constants.
+Status BackwardsConstAnalysis(const Graph& graph,
+                              std::vector<bool>* compile_time_const_args);
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_TF2XLA_CONST_ANALYSIS_H_
diff --git a/tensorflow/compiler/tf2xla/const_analysis_test.cc b/tensorflow/compiler/tf2xla/const_analysis_test.cc
new file mode 100644
index 0000000000..9d125f8d49
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/const_analysis_test.cc
@@ -0,0 +1,83 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Tests for the backward const analysis.
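+//
+// The Basics test below checks that constness propagates from the Reshape
+// "shape" input and the Sum "reduction_indices" input back to the matching
+// _Arg nodes, while arguments used only for their shape or only as data
+// values stay non-constant.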
+ +#include "tensorflow/compiler/tf2xla/const_analysis.h" + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +TEST(ConstAnalysisTest, Basics) { + Scope root = Scope::NewRootScope(); + + auto arg0 = ops::_Arg(root.WithOpName("Arg0"), DT_INT32, 0); + auto arg1 = ops::_Arg(root.WithOpName("Arg1"), DT_INT32, 1); + auto arg2 = ops::_Arg(root.WithOpName("Arg2"), DT_INT32, 2); + auto arg3 = ops::_Arg(root.WithOpName("Arg3"), DT_INT32, 3); + auto a = ops::Shape(root, arg0); + auto b = ops::Add(root, a, arg1); + auto c = ops::Reshape(root, arg2, b); + auto d = ops::Mul(root, c, ops::Sum(root, arg3, arg3)); + + Graph graph(OpRegistry::Global()); + TF_ASSERT_OK(root.ToGraph(&graph)); + + std::vector const_args(4, false); + TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args)); + + // Arg 0 doesn't need to be constant since the graph only uses its shape. + // Arg 1 must be constant because it flows to the shape argument of a Reshape. + // Arg 2 is used only as the value input to a Reshape and need not be const. + // Arg 3 is used as the reduction-indices argument to Sum and must be const. + EXPECT_EQ(const_args, std::vector({false, true, false, true})); +} + +// Regression test for a case where the backward const analysis did +// not visit nodes in topological order. +TEST(ConstAnalysisTest, TopologicalOrder) { + for (bool order : {false, true}) { + Scope root = Scope::NewRootScope(); + + auto arg0 = ops::_Arg(root.WithOpName("Arg0"), DT_INT32, 0); + auto arg1 = ops::_Arg(root.WithOpName("Arg1"), DT_INT32, 1); + auto arg2 = ops::_Arg(root.WithOpName("Arg2"), DT_INT32, 2); + auto a = ops::Reshape(root, arg0, arg1); + auto b = ops::Reshape(root, arg2, a); + if (order) { + // Consider both orders for arguments to the Sum so we aren't sensitive + // to the DFS traversal order. + std::swap(a, b); + } + auto c = ops::Add(root, a, b); + + Graph graph(OpRegistry::Global()); + TF_ASSERT_OK(root.ToGraph(&graph)); + + std::vector const_args(3, false); + TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args)); + + EXPECT_EQ(const_args, std::vector({true, true, false})); + } +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/dump_graph.cc b/tensorflow/compiler/tf2xla/dump_graph.cc new file mode 100644 index 0000000000..5aa6f806ac --- /dev/null +++ b/tensorflow/compiler/tf2xla/dump_graph.cc @@ -0,0 +1,78 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Helper functions for dumping Graphs, GraphDefs, and FunctionDefs to files for +// debugging. 
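+//
+// Typical use (illustrative sketch; the graph_def variable is assumed to
+// exist at the call site):
+//   string path = dump_graph::DumpGraphDefToFile("before_xla", graph_def);
+//   VLOG(1) << "Dumped GraphDef to " << path;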
+ +#include "tensorflow/compiler/tf2xla/dump_graph.h" + +#include "tensorflow/compiler/tf2xla/dump_graph_flags.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { +namespace dump_graph { + +namespace { + +struct NameCounts { + mutex counts_mutex; + std::unordered_map counts; +}; + +string MakeUniquePath(const string& name) { + static NameCounts& instance = *new NameCounts; + int count; + { + mutex_lock lock(instance.counts_mutex); + count = instance.counts[name]++; + } + + legacy_flags::DumpGraphFlags* flags = legacy_flags::GetDumpGraphFlags(); + string path = strings::StrCat(flags->tf_dump_graph_prefix, "/", name); + if (count > 0) { + strings::StrAppend(&path, "_", count); + } + strings::StrAppend(&path, ".pbtxt"); + return path; +} + +} // anonymous namespace + +string DumpGraphDefToFile(const string& name, GraphDef const& graph_def) { + string path = MakeUniquePath(name); + TF_CHECK_OK(WriteTextProto(Env::Default(), path, graph_def)); + return path; +} + +string DumpGraphToFile(const string& name, Graph const& graph, + const FunctionLibraryDefinition* flib_def) { + GraphDef graph_def; + graph.ToGraphDef(&graph_def); + if (flib_def) { + *graph_def.mutable_library() = flib_def->ToProto(); + } + return DumpGraphDefToFile(name, graph_def); +} + +string DumpFunctionDefToFile(const string& name, FunctionDef const& fdef) { + string path = MakeUniquePath(name); + TF_CHECK_OK(WriteTextProto(Env::Default(), path, fdef)); + return path; +} + +} // namespace dump_graph +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/dump_graph.h b/tensorflow/compiler/tf2xla/dump_graph.h new file mode 100644 index 0000000000..bbf01eb90d --- /dev/null +++ b/tensorflow/compiler/tf2xla/dump_graph.h @@ -0,0 +1,50 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Helper functions for dumping Graphs, GraphDefs, and FunctionDefs to files for +// debugging. + +#ifndef TENSORFLOW_COMPILER_TF2XLA_DUMP_GRAPH_H_ +#define TENSORFLOW_COMPILER_TF2XLA_DUMP_GRAPH_H_ + +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { +namespace dump_graph { + +// Dumps 'graph_def' to a file, as a GraphDef text proto. Returns the file name +// chosen. +// +// Automatically picks a file name. Prefixes 'name' with the value of the +// --tf_dump_graph_prefix flag and suffixes it with ".pbtxt" to form a name. +// If a graph has already been dumped by this process with the same name, +// suffixes with "_n.pbtxt", where 'n' is a sequence number. +string DumpGraphDefToFile(const string& name, GraphDef const& graph_def); + +// Similar to DumpGraphDefToFile, but builds the GraphDef to dump from a 'graph' +// and an optional function library 'flib_def'. Returns the file name chosen. 
+string DumpGraphToFile(const string& name, Graph const& graph,
+                       const FunctionLibraryDefinition* flib_def = nullptr);
+
+// Similar to DumpGraphDefToFile, but dumps a function as a FunctionDef text
+// proto. Returns the file name chosen.
+string DumpFunctionDefToFile(const string& name, FunctionDef const& fdef);
+
+}  // namespace dump_graph
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_TF2XLA_DUMP_GRAPH_H_
diff --git a/tensorflow/compiler/tf2xla/dump_graph_flags.cc b/tensorflow/compiler/tf2xla/dump_graph_flags.cc
new file mode 100644
index 0000000000..a6c908ba01
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/dump_graph_flags.cc
@@ -0,0 +1,63 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Legacy flags for the XLA bridge's dump_graph module.
+
+#include <mutex>
+#include <vector>
+
+#include "tensorflow/compiler/tf2xla/dump_graph_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace tensorflow {
+namespace legacy_flags {
+
+// Pointers to the parsed value of the flags and flag descriptors, initialized
+// via flags_init.
+static DumpGraphFlags* flags;
+static std::vector<Flag>* flag_list;
+static std::once_flag flags_init;
+
+// Allocate *flags. Called via call_once(&flags_init,...).
+static void AllocateFlags() {
+  flags = new DumpGraphFlags;
+  flags->tf_dump_graph_prefix = "/tmp/";
+  flag_list = new std::vector<Flag>({
+      Flag("tf_dump_graph_prefix", &flags->tf_dump_graph_prefix,
+           "Path prefix to which graphs dumped during debugging should be "
+           "written."),
+  });
+  xla::legacy_flags::ParseFlagsFromEnv(*flag_list);
+}
+
+// Append to *append_to flag definitions associated with the XLA bridge's
+// dump_graph module.
+void AppendDumpGraphFlags(std::vector<Flag>* append_to) {
+  std::call_once(flags_init, &AllocateFlags);
+  append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
+}
+
+// Return a pointer to the DumpGraphFlags struct;
+// repeated calls return the same pointer.
+// This should be called only after Flags::Parse() has returned.
+DumpGraphFlags* GetDumpGraphFlags() {
+  std::call_once(flags_init, &AllocateFlags);
+  return flags;
+}
+
+}  // namespace legacy_flags
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/dump_graph_flags.h b/tensorflow/compiler/tf2xla/dump_graph_flags.h
new file mode 100644
index 0000000000..80a3307d92
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/dump_graph_flags.h
@@ -0,0 +1,48 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_DUMP_GRAPH_FLAGS_H_ +#define TENSORFLOW_COMPILER_TF2XLA_DUMP_GRAPH_FLAGS_H_ + +// Legacy flags for the XLA bridge's dump_graph module. + +#include + +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace tensorflow { +namespace legacy_flags { + +// Append to *flag_list flag definitions associated with the XLA bridge's +// dump_graph module. +void AppendDumpGraphFlags(std::vector* flag_list); + +// The values of flags associated with the XLA bridge's +// dump_graph module. +typedef struct { + string tf_dump_graph_prefix; // Path prefix to which graphs dumped during + // debugging should be written. +} DumpGraphFlags; + +// Return a pointer to the DumpGraphFlags struct; +// repeated calls return the same pointer. +// This should be called only after Flags::Parse() has returned. +DumpGraphFlags* GetDumpGraphFlags(); + +} // namespace legacy_flags +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_DUMP_GRAPH_FLAGS_H_ diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD new file mode 100644 index 0000000000..d913f898e9 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -0,0 +1,177 @@ +licenses(["notice"]) # Apache 2.0 + +package( + default_visibility = [ + "//tensorflow/compiler/tf2xla:internal", + ], + features = ["no_layering_check"], +) + +load("//tensorflow:tensorflow.bzl", "tf_kernel_library") +load("//tensorflow/compiler/xla:xla.bzl", "export_dynamic_linkopts") + +tf_kernel_library( + name = "xla_ops", + srcs = [ + "aggregate_ops.cc", + "batch_matmul_op.cc", + "bcast_ops.cc", + "bias_ops.cc", + "binary_ops.cc", + "cast_op.cc", + "concat_op.cc", + "conv_ops.cc", + "cwise_ops.cc", + "declaration_op.cc", + "depthwise_conv_ops.cc", + "diag_op.cc", + "dynamic_stitch_op.cc", + "fill_op.cc", + "function_ops.cc", + "identity_op.cc", + "l2loss_op.cc", + "lrn_ops.cc", + "matmul_op.cc", + "no_op.cc", + "pack_op.cc", + "pad_op.cc", + "pooling_ops.cc", + "random_ops.cc", + "reduction_ops.cc", + "reduction_ops_common.cc", + "relu_op.cc", + "reshape_op.cc", + "retval_op.cc", + "select_op.cc", + "sequence_ops.cc", + "shape_op.cc", + "slice_op.cc", + "softmax_op.cc", + "split_op.cc", + "strided_slice_op.cc", + "tile_ops.cc", + "transpose_op.cc", + "unary_ops.cc", + "unpack_op.cc", + ], + hdrs = [ + "cwise_ops.h", + "reduction_ops.h", + ], + deps = [ + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:tensorflow_opensource", + "//tensorflow/core/kernels:concat_lib", + "//tensorflow/core/kernels:conv_2d", + "//tensorflow/core/kernels:conv_ops", + "//tensorflow/core/kernels:cwise_op", + "//tensorflow/core/kernels:depthwise_conv_op", + 
"//tensorflow/core/kernels:matmul_op", + "//tensorflow/core/kernels:no_op", + "//tensorflow/core/kernels:ops_util", + "//tensorflow/core/kernels:pooling_ops", + "//tensorflow/core/kernels:sendrecv_ops", + "//tensorflow/core/kernels:transpose_op", + ], +) + +# Kernels that only work on CPU, because they use XLA custom calls. +# Only link this when using the CPU backend for XLA. +# +# TODO(cwhipkey): move into xla_ops when ops can be registered for +# CPU compilation only (b/31363654). +tf_kernel_library( + name = "xla_cpu_only_ops", + srcs = [ + "gather_op.cc", + "index_ops.cc", + ], + deps = [ + ":gather_op_kernel_float_int32", + ":gather_op_kernel_float_int64", + ":index_ops_kernel_argmax_float_1d", + ":index_ops_kernel_argmax_float_2d", + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:tensorflow_opensource", + ], +) + +tf_kernel_library( + name = "gather_op_kernel_float_int32", + srcs = ["gather_op_kernel_float_int32.cc"], + # Makes the custom-call function visible to LLVM during JIT. + linkopts = export_dynamic_linkopts, + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/tf2xla:xla_local_runtime_context", + "//tensorflow/core:framework_lite", + "//tensorflow/core/kernels:gather_functor", + "//third_party/eigen3", + ], +) + +tf_kernel_library( + name = "gather_op_kernel_float_int64", + srcs = ["gather_op_kernel_float_int64.cc"], + # Makes the custom-call function visible to LLVM during JIT. + linkopts = export_dynamic_linkopts, + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/tf2xla:xla_local_runtime_context", + "//tensorflow/core:framework_lite", + "//tensorflow/core/kernels:gather_functor", + "//third_party/eigen3", + ], +) + +tf_kernel_library( + name = "index_ops_kernel_argmax_float_1d", + srcs = ["index_ops_kernel_argmax_float_1d.cc"], + # Makes the custom-call function visible to LLVM during JIT. + linkopts = export_dynamic_linkopts, + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:framework_lite", + "//third_party/eigen3", + ], +) + +tf_kernel_library( + name = "index_ops_kernel_argmax_float_2d", + srcs = ["index_ops_kernel_argmax_float_2d.cc"], + # Makes the custom-call function visible to LLVM during JIT. + linkopts = export_dynamic_linkopts, + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:framework_lite", + "//third_party/eigen3", + ], +) + +# ----------------------------------------------------------------------------- + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc b/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc new file mode 100644 index 0000000000..8f284c3017 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc @@ -0,0 +1,47 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" + +namespace tensorflow { +namespace { + +class AddNOp : public XlaOpKernel { + public: + explicit AddNOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + if (!ctx->ValidateInputsAreSameShape(this)) return; + + OP_REQUIRES(ctx, ctx->num_inputs() >= 1, + errors::InvalidArgument("AddN requires at least one argument")); + + xla::ComputationDataHandle sum = ctx->Input(0); + for (int i = 1; i < ctx->num_inputs(); ++i) { + sum = ctx->builder()->Add(sum, ctx->Input(i)); + } + + ctx->SetOutput(0, sum); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(AddNOp); +}; + +REGISTER_XLA_OP("AddN", AddNOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc new file mode 100644 index 0000000000..637360d149 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc @@ -0,0 +1,141 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA-specific BatchMatMul Op. +// The current implementation simply unrolls the computation along the batch +// dimension. +// TODO(dominikg,phawkins): Use a real batched matmul instead of unrolling. + +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" + +namespace tensorflow { +namespace { + +class BatchMatMulOp : public XlaOpKernel { + public: + explicit BatchMatMulOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("adj_x", &adj_x_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("adj_y", &adj_y_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape x_shape = ctx->InputShape(0); + const TensorShape y_shape = ctx->InputShape(1); + + // Check that both tensors have the same number of dimensions. There must be + // at least two (the batch dimensions can be empty). + OP_REQUIRES(ctx, x_shape.dims() == y_shape.dims(), + errors::InvalidArgument("In[0] and In[1] has different ndims: ", + x_shape.DebugString(), " vs. 
", + y_shape.DebugString())); + const int ndims = x_shape.dims(); + OP_REQUIRES( + ctx, ndims >= 2, + errors::InvalidArgument("In[0] and In[1] ndims must be >= 2: ", ndims)); + + // The batch dimensions must be equal and the matrix dimensions must be + // valid. + std::vector dimensions; + int batch_count = 1; + for (int i = 0; i < ndims - 2; ++i) { + OP_REQUIRES( + ctx, x_shape.dim_size(i) == y_shape.dim_size(i), + errors::InvalidArgument("In[0].dim(", i, ") and In[1].dim(", i, + ") must be the same: ", x_shape.DebugString(), + " vs ", y_shape.DebugString())); + dimensions.push_back(x_shape.dim_size(i)); + batch_count *= x_shape.dim_size(i); + } + + int x_inner_dim = adj_x_ ? (ndims - 2) : (ndims - 1); + int y_inner_dim = adj_y_ ? (ndims - 1) : (ndims - 2); + OP_REQUIRES( + ctx, x_shape.dim_size(x_inner_dim) == y_shape.dim_size(y_inner_dim), + errors::InvalidArgument( + "In[0] mismatch In[1] shape: ", x_shape.dim_size(x_inner_dim), + " vs. ", y_shape.dim_size(y_inner_dim), ": ", x_shape.DebugString(), + " ", y_shape.DebugString(), " ", adj_x_, " ", adj_y_)); + + int x_outer_dim = adj_x_ ? (ndims - 1) : (ndims - 2); + int y_outer_dim = adj_y_ ? (ndims - 2) : (ndims - 1); + dimensions.push_back(x_shape.dim_size(x_outer_dim)); + dimensions.push_back(y_shape.dim_size(y_outer_dim)); + + xla::ComputationBuilder* builder = ctx->builder(); + + xla::ComputationDataHandle x_handle = ctx->Input(0); + xla::ComputationDataHandle y_handle = ctx->Input(1); + + // Reshape input tensors into 3D tensors by flattening the batch + // dimensions. This makes it easier to unroll the batch dimension. + auto x_flat = + builder->Reshape(x_handle, {batch_count, x_shape.dim_size(ndims - 2), + x_shape.dim_size(ndims - 1)}); + auto y_flat = + builder->Reshape(y_handle, {batch_count, y_shape.dim_size(ndims - 2), + y_shape.dim_size(ndims - 1)}); + + // Slice batches into individual matrices and multiply them. + std::vector out_slices; + for (int i = 0; i < batch_count; ++i) { + // Slice off individual matrices and reshape to 2D tensors. + auto x_slice = builder->Slice( + x_flat, {i, 0, 0}, + {i + 1, x_shape.dim_size(ndims - 2), x_shape.dim_size(ndims - 1)}); + x_slice = builder->Reshape( + x_slice, {x_shape.dim_size(ndims - 2), x_shape.dim_size(ndims - 1)}); + auto y_slice = builder->Slice( + y_flat, {i, 0, 0}, + {i + 1, y_shape.dim_size(ndims - 2), y_shape.dim_size(ndims - 1)}); + y_slice = builder->Reshape( + y_slice, {y_shape.dim_size(ndims - 2), y_shape.dim_size(ndims - 1)}); + + // Transpose if needed. + auto lhs = adj_x_ ? builder->Transpose(x_slice, {1, 0}) : x_slice; + auto rhs = adj_y_ ? builder->Transpose(y_slice, {1, 0}) : y_slice; + + // Multiply matrices and add an outer singleton dimension to the output + // so we can concatenate along the flattened batch dimension later. + auto out = builder->Dot(lhs, rhs); + out = builder->Reshape(out, + {1, dimensions[ndims - 2], dimensions[ndims - 1]}); + out_slices.push_back(out); + } + + // Concatenate output slices and reshape to original number of dimensions. + xla::ComputationDataHandle data; + if (out_slices.empty()) { + // It is illegal to pass an empty list to ConcatInDim. + // The batch count is empty, so both inputs must have zero elements. + // Arbitrarily use the left input as the argument to Reshape(). 
+ data = x_handle; + } else { + data = builder->ConcatInDim(out_slices, 0); + } + data = builder->Reshape(data, dimensions); + + ctx->SetOutput(0, data); + } + + private: + bool adj_x_; + bool adj_y_; +}; + +REGISTER_XLA_OP("BatchMatMul", BatchMatMulOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc new file mode 100644 index 0000000000..f35835df08 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc @@ -0,0 +1,87 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA-specific Ops for broadcasting used in gradient +// code. + +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/bcast.h" + +namespace tensorflow { +namespace { + +// Given shapes of two tensors, computes the reduction indices for the +// gradient computation. +// +// TODO(zhifengc): +// 1. Adds support for n-ary (n >= 2). +class BCastGradArgsOp : public XlaOpKernel { + public: + explicit BCastGradArgsOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK( + ctx, ctx->MatchSignature({DT_INT32, DT_INT32}, {DT_INT32, DT_INT32})); + } + + void Compile(XlaOpKernelContext* ctx) override { + OP_REQUIRES( + ctx, ctx->num_inputs() == 2, + errors::Unimplemented("Broadcast for n-ary operations (n > 2)")); + + gtl::InlinedVector shapes; + for (int i = 0; i < ctx->num_inputs(); ++i) { + const TensorShape in_shape = ctx->InputShape(i); + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(in_shape), + errors::InvalidArgument("In[", i, "] must be a vector.", + in_shape.DebugString())); + xla::Literal literal; + OP_REQUIRES_OK(ctx, ctx->ConstantInput(i, &literal)); + + BCast::Vec vec; + for (int64 i = 0; i < in_shape.num_elements(); ++i) { + vec.push_back(xla::LiteralUtil::Get(literal, {i})); + } + shapes.push_back(vec); + } + BCast bcast(shapes[0], shapes[1]); + OP_REQUIRES(ctx, bcast.IsValid(), + errors::InvalidArgument( + "Incompatible shapes: [", str_util::Join(shapes[0], ","), + "] vs. 
[", str_util::Join(shapes[1], ","), "]")); + Output(ctx, 0, bcast.grad_x_reduce_idx()); + Output(ctx, 1, bcast.grad_y_reduce_idx()); + } + + private: + void Output(XlaOpKernelContext* ctx, int idx, const BCast::Vec& v) { + const int64 len = v.size(); + Tensor constant(DT_INT32, TensorShape({len})); + for (int64 i = 0; i < len; ++i) { + constant.flat()(i) = static_cast(v[i]); + } + ctx->SetConstantOutput(idx, constant); + } + + TF_DISALLOW_COPY_AND_ASSIGN(BCastGradArgsOp); +}; + +REGISTER_XLA_OP("BroadcastGradientArgs", BCastGradArgsOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc new file mode 100644 index 0000000000..217e82304e --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc @@ -0,0 +1,119 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/util/tensor_format.h" + +namespace tensorflow { +namespace { + +class BiasOp : public XlaOpKernel { + public: + explicit BiasOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + string data_format; + if (ctx->GetAttr("data_format", &data_format).ok()) { + OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), + errors::InvalidArgument("Invalid data format")); + } else { + data_format_ = FORMAT_NHWC; + } + } + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape input_shape = ctx->InputShape(0); + const TensorShape bias_shape = ctx->InputShape(1); + + OP_REQUIRES(ctx, TensorShapeUtils::IsMatrixOrHigher(input_shape), + errors::InvalidArgument("Input tensor must be at least 2D: ", + input_shape.DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(bias_shape), + errors::InvalidArgument("Biases must be 1D: ", + bias_shape.DebugString())); + int feature_dim = (data_format_ == FORMAT_NHWC) ? input_shape.dims() - 1 + : input_shape.dims() - 3; + OP_REQUIRES( + ctx, feature_dim >= 0, + errors::InvalidArgument("Input tensor does not have enough dimensions " + "to contain the feature dimension")); + OP_REQUIRES( + ctx, bias_shape.dim_size(0) == input_shape.dim_size(feature_dim), + errors::InvalidArgument( + "Must provide as many biases as the last dimension " + "of the input tensor: ", + bias_shape.DebugString(), " vs. 
", input_shape.DebugString())); + + xla::ComputationDataHandle result = + ctx->builder()->Add(ctx->Input(0), ctx->Input(1), {feature_dim}); + ctx->SetOutput(0, result); + } + + private: + TensorFormat data_format_; +}; + +REGISTER_XLA_OP("BiasAdd", BiasOp); +REGISTER_XLA_OP("BiasAddV1", BiasOp); + +class BiasAddGradOp : public XlaOpKernel { + public: + explicit BiasAddGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + string data_format; + if (ctx->GetAttr("data_format", &data_format).ok()) { + OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), + errors::InvalidArgument("Invalid data format")); + } else { + data_format_ = FORMAT_NHWC; + } + } + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape out_backprop_shape = ctx->InputShape(0); + + OP_REQUIRES(ctx, TensorShapeUtils::IsMatrixOrHigher(out_backprop_shape), + errors::InvalidArgument("Input tensor must be at least 2D: ", + out_backprop_shape.DebugString())); + + int feature_dim = (data_format_ == FORMAT_NHWC) + ? out_backprop_shape.dims() - 1 + : out_backprop_shape.dims() - 3; + OP_REQUIRES( + ctx, feature_dim >= 0, + errors::InvalidArgument("Input tensor does not have enough dimensions " + "to contain the feature dimension")); + + std::vector reduce_dims(out_backprop_shape.dims() - 1); + std::iota(reduce_dims.begin(), reduce_dims.begin() + feature_dim, 0); + std::iota(reduce_dims.begin() + feature_dim, reduce_dims.end(), + feature_dim + 1); + xla::ComputationDataHandle result = ctx->builder()->Reduce( + ctx->Input(0), XlaHelpers::Zero(ctx->builder(), input_type(0)), + *ctx->GetOrCreateAdd(input_type(0)), reduce_dims); + ctx->SetOutput(0, result); + } + + private: + TensorFormat data_format_; +}; + +REGISTER_XLA_OP("BiasAddGrad", BiasAddGradOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc new file mode 100644 index 0000000000..6f117ebe61 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -0,0 +1,158 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Native XLA implementations of simple unary Ops + +#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/core/framework/kernel_def_builder.h" + +namespace tensorflow { +namespace { + +// A subclass of a XlaBinaryOp must build the computation that +// describes the (tensor,tensor)->tensor function to apply to each element of +// the input. 
+#define XLA_MAKE_BINARY(Name, HLO) \ + class Name##Op : public XlaBinaryOp { \ + public: \ + explicit Name##Op(OpKernelConstruction* ctx) : XlaBinaryOp(ctx) {} \ + xla::ComputationDataHandle Computation( \ + XlaOpKernelContext* ctx, const xla::ComputationDataHandle& lhs, \ + const gtl::ArraySlice& lhs_shape, \ + const xla::ComputationDataHandle& rhs, \ + const gtl::ArraySlice& rhs_shape, \ + const BCast& broadcast_helper, \ + const std::vector& extend_dimensions) override { \ + xla::ComputationBuilder* b = ctx->builder(); \ + return HLO; \ + } \ + }; \ + REGISTER_XLA_OP(#Name, Name##Op) + +XLA_MAKE_BINARY(Add, b->Add(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Sub, b->Sub(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Mul, b->Mul(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Div, b->Div(lhs, rhs, extend_dimensions)); + +// Implementation of FloorDiv. Pseudo-code: +// if ((x < 0) != (y < 0)) { +// T abs_x = std::abs(x); +// T abs_y = std::abs(y); +// return -(abs_x + abs_y - 1) / abs_y; +// } else { +// return x / y; +// } +static xla::ComputationDataHandle FloorDivImpl(xla::ComputationBuilder* b, + DataType dtype, + xla::ComputationDataHandle x, + xla::ComputationDataHandle y, + const BCast& broadcast_helper) { + std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper); + auto zero = XlaHelpers::Zero(b, dtype); + auto one = XlaHelpers::One(b, dtype); + auto different_sign = b->Ne(b->Lt(x, zero), b->Lt(y, zero)); + auto abs_x = b->Abs(x); + auto abs_y = b->Abs(y); + auto t = b->Neg(b->Sub(b->Add(abs_x, abs_y), one)); + auto result = b->Select(different_sign, b->Div(t, abs_y), b->Div(x, y)); + if (dtype == DT_FLOAT || dtype == DT_DOUBLE) { + result = b->Floor(result); + } + return result; +} +XLA_MAKE_BINARY(FloorDiv, + FloorDivImpl(b, input_type(0), lhs, rhs, broadcast_helper)); + +// Implementation of FloorMod. Pseudo-code: +// T trunc_mod = std::fmod(x, y); +// return (x < T(0)) == (y < T(0)) ? 
trunc_mod : std::fmod(trunc_mod + y, y); +static xla::ComputationDataHandle FloorModImpl(xla::ComputationBuilder* b, + DataType dtype, + xla::ComputationDataHandle x, + xla::ComputationDataHandle y, + const BCast& broadcast_helper) { + std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper); + auto zero = XlaHelpers::Zero(b, dtype); + auto same_sign = b->Eq(b->Lt(x, zero), b->Lt(y, zero)); + auto trunc_mod = b->Rem(x, y); + return b->Select(same_sign, trunc_mod, b->Rem(b->Add(trunc_mod, y), y)); +} +XLA_MAKE_BINARY(FloorMod, + FloorModImpl(b, input_type(0), lhs, rhs, broadcast_helper)); + +XLA_MAKE_BINARY(LogicalAnd, b->LogicalAnd(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(LogicalOr, b->LogicalOr(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Mod, b->Rem(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Maximum, b->Max(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Minimum, b->Min(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(RealDiv, b->Div(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY( + RsqrtGrad, + b->Mul(b->Pow(lhs, XlaHelpers::IntegerLiteral(b, input_type(0), 3)), + b->Div(rhs, XlaHelpers::IntegerLiteral(b, input_type(0), -2)), + extend_dimensions)); + +static xla::ComputationDataHandle Square(xla::ComputationBuilder* builder, + const xla::ComputationDataHandle& x) { + return builder->Mul(x, x); +} + +XLA_MAKE_BINARY(SquaredDifference, + Square(b, b->Sub(lhs, rhs, extend_dimensions))); + +XLA_MAKE_BINARY(TruncateDiv, b->Div(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(TruncateMod, b->Rem(lhs, rhs, extend_dimensions)); + +// Comparison ops +XLA_MAKE_BINARY(Equal, b->Eq(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(NotEqual, b->Ne(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Greater, b->Gt(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(GreaterEqual, b->Ge(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Less, b->Lt(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(LessEqual, b->Le(lhs, rhs, extend_dimensions)); + +#undef XLA_MAKE_BINARY + +#define XLA_MAKE_BINARY_MAP(Name, HLO) \ + class Name##Op : public XlaBinaryMapOp { \ + public: \ + explicit Name##Op(OpKernelConstruction* ctx) : XlaBinaryMapOp(ctx) {} \ + void BuildMapLambda(xla::ComputationBuilder* b, \ + const xla::ComputationDataHandle& lhs, \ + const xla::ComputationDataHandle& rhs) override { \ + HLO; \ + } \ + }; \ + REGISTER_XLA_OP(#Name, Name##Op) + +XLA_MAKE_BINARY_MAP(Pow, b->Pow(lhs, rhs)); +XLA_MAKE_BINARY_MAP(SigmoidGrad, + b->Mul(b->Mul(rhs, lhs), + b->Sub(XlaHelpers::One(b, input_type(0)), lhs))); +XLA_MAKE_BINARY_MAP(SoftplusGrad, + b->Div(lhs, b->Add(b->Exp(b->Neg(rhs)), + XlaHelpers::One(b, input_type(1))))); +XLA_MAKE_BINARY_MAP(TanhGrad, + b->Mul(rhs, b->Sub(XlaHelpers::One(b, input_type(0)), + b->Mul(lhs, lhs)))); + +#undef XLA_MAKE_BINARY_MAP + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/cast_op.cc b/tensorflow/compiler/tf2xla/kernels/cast_op.cc new file mode 100644 index 0000000000..b0188b4f8d --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/cast_op.cc @@ -0,0 +1,71 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/core/framework/kernel_def_builder.h" + +namespace tensorflow { +namespace { + +class CastOp : public XlaOpKernel { + public: + explicit CastOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("SrcT", &src_dtype_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("DstT", &dst_dtype_)); + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(src_dtype_, &src_type_)); + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dst_dtype_, &dst_type_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* builder = ctx->builder(); + xla::ComputationDataHandle input = ctx->Input(0); + xla::ComputationDataHandle output; + + if (src_dtype_ == dst_dtype_) { + output = input; + } else if (src_dtype_ == DT_BOOL) { + // XLA's ConvertElementType doesn't support casting to/from + // bools. So we need to handle those cases separately. + // Builds the equivalent of (input ? 1 : 0) + xla::ComputationBuilder l(builder->client(), "PredCast"); + xla::ComputationDataHandle x = + l.Parameter(0, xla::ShapeUtil::MakeShape(src_type_, {}), "x"); + l.Select(x, XlaHelpers::One(&l, dst_dtype_), + XlaHelpers::Zero(&l, dst_dtype_)); + xla::Computation computation = l.Build().ConsumeValueOrDie(); + output = builder->Map({input}, computation); + } else if (dst_dtype_ == DT_BOOL) { + output = builder->Ne(input, XlaHelpers::Zero(builder, src_dtype_)); + } else { + output = builder->ConvertElementType(input, dst_type_); + } + + ctx->SetOutput(0, output); + } + + protected: + DataType src_dtype_, dst_dtype_; + xla::PrimitiveType src_type_, dst_type_; + + TF_DISALLOW_COPY_AND_ASSIGN(CastOp); +}; + +REGISTER_XLA_OP("Cast", CastOp); + +} // anonymous namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc new file mode 100644 index 0000000000..96ef2ac20c --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc @@ -0,0 +1,210 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA-specific Concat Ops. 
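Before the Concat kernels, a short aside on the CastOp just above: because XLA's ConvertElementType does not handle booleans, the kernel builds (input ? 1 : 0) for a bool source and (input != 0) for a bool destination. A minimal scalar sketch of those two rules (the template helpers are illustrative only):

#include <cstdio>

// bool -> numeric: the scalar equivalent of Select(x, One, Zero).
template <typename T>
T CastFromBool(bool x) { return x ? T(1) : T(0); }

// numeric -> bool: the scalar equivalent of Ne(x, Zero).
template <typename T>
bool CastToBool(T x) { return x != T(0); }

int main() {
  std::printf("%d %f %d %d\n",
              CastFromBool<int>(true), CastFromBool<double>(false),
              static_cast<int>(CastToBool<float>(0.0f)),
              static_cast<int>(CastToBool<int>(-3)));  // 1 0.000000 0 1
  return 0;
}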
+ +#include +#include + +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/kernels/concat_lib.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace { + +// -------------------------------------------------------------------------- +class ConcatBaseOp : public XlaOpKernel { + public: + ConcatBaseOp(OpKernelConstruction* c, int axis_index) + : XlaOpKernel(c), axis_index_(axis_index) {} + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape concat_dim_tensor_shape = ctx->InputShape(axis_index_); + OP_REQUIRES( + ctx, IsLegacyScalar(concat_dim_tensor_shape), + errors::InvalidArgument( + "Concat dim tensor should be a scalar integer, but got shape ", + concat_dim_tensor_shape.DebugString())); + xla::Literal literal; + OP_REQUIRES_OK(ctx, ctx->ConstantInput(axis_index_, &literal)); + // TODO(annarev): add a helper to support int64 input. + const int32 concat_dim = xla::LiteralUtil::Get(literal, {}); + + std::vector values; + std::vector shapes; + OP_REQUIRES_OK(ctx, ctx->InputList("values", &values, &shapes)); + const int N = values.size(); + const int input_dims = shapes[0].dims(); + const TensorShape& input_shape = shapes[0]; + + int32 axis = concat_dim < 0 ? concat_dim + input_dims : concat_dim; + OP_REQUIRES(ctx, + (0 <= axis && axis < input_dims) || + (allow_legacy_scalars() && concat_dim == 0), + errors::InvalidArgument( + "ConcatOp : Expected concatenating dimensions in the range " + "[", + -input_dims, ", ", input_dims, "), but got ", concat_dim)); + + // Make a vector holding the ComputationDataHandles for each of + // the inputs that has non-zero elements. + std::vector input_data; + int output_concat_dim = 0; + const bool input_is_scalar = IsLegacyScalar(input_shape); + for (int i = 0; i < N; ++i) { + xla::ComputationDataHandle handle = values[i]; + const TensorShape& in_shape = shapes[i]; + const bool in_is_scalar = IsLegacyScalar(in_shape); + OP_REQUIRES( + ctx, + in_shape.dims() == input_dims || (input_is_scalar && in_is_scalar), + errors::InvalidArgument( + "ConcatOp : Ranks of all input tensors should match: shape[0] = ", + input_shape.DebugString(), " vs. shape[", i, "] = ", + in_shape.DebugString())); + if (in_shape.dims() == 0) { + // Inputs that come in as scalars must be reshaped to 1-vectors. + input_data.push_back(ctx->builder()->Reshape(handle, {1})); + } else { + input_data.push_back(handle); + } + output_concat_dim += in_shape.dims() > 0 ? in_shape.dim_size(axis) : 1; + } + + VLOG(1) << "Concat dim " << concat_dim << " equivalent to " << axis; + ctx->SetOutput(0, ctx->builder()->ConcatInDim(input_data, axis)); + } + + private: + int axis_index_; +}; + +class ConcatOp : public ConcatBaseOp { + public: + explicit ConcatOp(OpKernelConstruction* c) + : ConcatBaseOp(c, /* axis_index */ 0) {} +}; + +// ConcatV2 operation is the same as Concat except 'concat_dim' +// is the last input instead of the first and renamed to 'axis'. 
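The ConcatBaseOp above normalizes a possibly negative concat dimension by adding the input rank, which is the usual TensorFlow convention shared by Concat and the ConcatV2 variant defined next. A minimal sketch of that normalization (the helper name is illustrative):

#include <cassert>
#include <cstdio>

// Map a possibly negative concat dimension into [0, rank).
int NormalizeConcatAxis(int concat_dim, int rank) {
  const int axis = concat_dim < 0 ? concat_dim + rank : concat_dim;
  assert(axis >= 0 && axis < rank);
  return axis;
}

int main() {
  std::printf("%d %d\n", NormalizeConcatAxis(-1, 4),  // 3
              NormalizeConcatAxis(1, 4));             // 1
  return 0;
}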
+class ConcatV2Op : public ConcatBaseOp { + public: + explicit ConcatV2Op(OpKernelConstruction* c) + : ConcatBaseOp(c, /* axis_index */ c->num_inputs() - 1) {} +}; + +REGISTER_XLA_OP("Concat", ConcatOp); +REGISTER_XLA_OP("ConcatV2", ConcatV2Op); + +class ConcatOffsetOp : public XlaOpKernel { + public: + explicit ConcatOffsetOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape concat_dim_shape = ctx->InputShape(0); + OP_REQUIRES( + ctx, IsLegacyScalar(concat_dim_shape), + errors::InvalidArgument( + "Concat dim tensor should be a scalar integer, but got shape ", + concat_dim_shape.DebugString())); + for (int i = 1; i < ctx->num_inputs(); ++i) { + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ctx->InputShape(i)), + errors::InvalidArgument("input ", i, + " should be a vector, but got shape ", + ctx->InputShape(i).DebugString())); + } + // Suppose a Concat() op needs to Concatenate N tensors, each of + // which has the same number of dimensions. Their shapes match + // except the concat dimension. + // + // E.g., say, we want to concatenate 3 tensors in the 2nd + // dimension, and their shapes are: + // + // [2, 2, 5, 7] + // [2, 3, 5, 7] + // [2, 4, 5, 7] + // + // Here, N=3, cdim=1, dims=4. The concatenated tensor has shape + // [2,9,5,7]. We will compute the cumulative sum along the 2nd + // dimension to figure out each input's offset in the concatenated + // output: + // [0, 0, 0, 0] + // [0, 2, 0, 0] + // [0, 5, 0, 0] + const int32 N = ctx->num_inputs() - 1; + const TensorShape inp0_shape = ctx->InputShape(1); + xla::Literal inp0_literal; + OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &inp0_literal)); + const int64 dims = inp0_shape.num_elements(); + + xla::Literal concat_dim_literal; + OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &concat_dim_literal)); + const int64 cdim = xla::LiteralUtil::Get(concat_dim_literal, {}); + + VLOG(1) << "ConcatOffset " << cdim << "," << dims; + int32 axis = cdim < 0 ? cdim + dims : cdim; + OP_REQUIRES(ctx, FastBoundsCheck(axis, dims), + errors::InvalidArgument("Concat dim is out of range: ", axis, + " vs. ", dims)); + int32 offset = 0; + for (int i = 0; i < N; ++i) { + const TensorShape inp_shape = ctx->InputShape(1 + i); + OP_REQUIRES(ctx, dims == inp_shape.num_elements(), + errors::InvalidArgument("input ", i, " should contain ", dims, + " elements, but got", + inp_shape.num_elements())); + xla::Literal inp_literal; + OP_REQUIRES_OK(ctx, ctx->ConstantInput(1 + i, &inp_literal)); + + Tensor out_constant(DT_INT32, TensorShape({dims})); + auto out_vec = out_constant.vec(); + for (int64 j = 0; j < dims; ++j) { + if (j == axis) { + out_vec(j) = offset; + offset += xla::LiteralUtil::Get(inp_literal, {j}); + } else { + const int32 inp0_element = + xla::LiteralUtil::Get(inp0_literal, {j}); + const int32 inp_element = + xla::LiteralUtil::Get(inp_literal, {j}); + OP_REQUIRES( + ctx, (inp0_element == inp_element), + errors::InvalidArgument("input[", i, ",", j, "] mismatch: ", + inp0_element, " vs. ", inp_element)); + out_vec(j) = 0; + } + } + + ctx->SetConstantOutput(i, out_constant); + } + } +}; + +REGISTER_XLA_OP("ConcatOffset", ConcatOffsetOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc new file mode 100644 index 0000000000..9bebfcfe47 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc @@ -0,0 +1,373 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
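The ConcatOffsetOp above emits, for each input, a constant vector that is zero everywhere except at the concat dimension, where it carries the running sum of the preceding inputs' extents. A minimal host-side sketch of that bookkeeping, reusing the shapes from the comment (the helper name is illustrative):

#include <cstdio>
#include <vector>

// For each input shape, produce its offset vector in the concatenated output:
// zeros everywhere except a running sum along the concat dimension `cdim`.
std::vector<std::vector<int>> ConcatOffsets(
    const std::vector<std::vector<int>>& shapes, int cdim) {
  std::vector<std::vector<int>> offsets;
  int running = 0;
  for (const auto& shape : shapes) {
    std::vector<int> row(shape.size(), 0);
    row[cdim] = running;
    running += shape[cdim];
    offsets.push_back(row);
  }
  return offsets;
}

int main() {
  // The example from the comment: three 4-D inputs concatenated along dim 1.
  auto offsets = ConcatOffsets({{2, 2, 5, 7}, {2, 3, 5, 7}, {2, 4, 5, 7}}, 1);
  for (const auto& row : offsets) {
    for (int v : row) std::printf("%d ", v);  // 0 0 0 0 / 0 2 0 0 / 0 5 0 0
    std::printf("\n");
  }
  return 0;
}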
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA-specific Ops for 2D convolution. + +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_slice.h" +#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/kernels/conv_2d.h" +#include "tensorflow/core/kernels/conv_grad_ops.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/util/padding.h" +#include "tensorflow/core/util/tensor_format.h" + +namespace tensorflow { + +namespace { + +class Conv2DOp : public XlaOpKernel { + public: + explicit Conv2DOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_)); + string data_format; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); + OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), + errors::InvalidArgument("Invalid data format")); + OP_REQUIRES(ctx, strides_.size() == 4, + errors::InvalidArgument("Sliding window strides field must " + "specify 4 dimensions")); + const int64 stride_n = GetTensorDim(strides_, data_format_, 'N'); + const int64 stride_c = GetTensorDim(strides_, data_format_, 'C'); + OP_REQUIRES( + ctx, stride_n == 1 && stride_c == 1, + errors::InvalidArgument("Current implementation does not yet support " + "strides in the batch and depth dimensions.")); + OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape input_shape = ctx->InputShape(0); + // Input filter is of the following dimensions: + // [ filter_rows, filter_cols, in_depth, out_depth] + const TensorShape filter_shape = ctx->InputShape(1); + + // For 2D convolution, there should be 4 dimensions. + OP_REQUIRES(ctx, input_shape.dims() == 4, + errors::InvalidArgument("input must be 4-dimensional", + input_shape.DebugString())); + OP_REQUIRES(ctx, filter_shape.dims() == 4, + errors::InvalidArgument("filter must be 4-dimensional: ", + filter_shape.DebugString())); + + // The 'C' dimension for input is in_depth. It must be the same as + // the filter's in_depth. + const int64 in_depth = GetTensorDim(input_shape, data_format_, 'C'); + OP_REQUIRES( + ctx, in_depth == filter_shape.dim_size(2), + errors::InvalidArgument("input and filter must have the same depth: ", + in_depth, " vs ", filter_shape.dim_size(2))); + + // The last dimension for filter is out_depth. + const int64 out_depth = filter_shape.dim_size(3); + + // The 'H' dimension for input is rows/height. + // The first dimension for filter is rows/height. 
+ const int64 input_rows = GetTensorDim(input_shape, data_format_, 'H'); + const int64 filter_rows = filter_shape.dim_size(0); + + // The 'W' dimension for input is columns/width. + // The second dimension for filter is columns/width. + const int64 input_cols = GetTensorDim(input_shape, data_format_, 'W'); + const int64 filter_cols = filter_shape.dim_size(1); + + // For now we take the stride from the H and W dimensions only (we + // do not support striding on the batch or depth dimension). + const int stride_rows = GetTensorDim(strides_, data_format_, 'H'); + const int stride_cols = GetTensorDim(strides_, data_format_, 'W'); + + int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0; + OP_REQUIRES_OK(ctx, + GetWindowedOutputSize(input_rows, filter_rows, stride_rows, + padding_, &out_rows, &pad_rows)); + OP_REQUIRES_OK(ctx, + GetWindowedOutputSize(input_cols, filter_cols, stride_cols, + padding_, &out_cols, &pad_cols)); + + VLOG(2) << "Conv2D: in_depth = " << in_depth + << ", input_cols = " << input_cols + << ", filter_cols = " << filter_cols + << ", input_rows = " << input_rows + << ", filter_rows = " << filter_rows + << ", stride_rows = " << stride_rows + << ", stride_cols = " << stride_cols + << ", out_depth = " << out_depth; + + xla::ConvolutionDimensionNumbers dims; + dims.set_batch_dimension(GetTensorDimIndex<2>(data_format_, 'N')); + dims.set_feature_dimension(GetTensorDimIndex<2>(data_format_, 'C')); + dims.add_spatial_dimensions(GetTensorDimIndex<2>(data_format_, 'H')); + dims.add_spatial_dimensions(GetTensorDimIndex<2>(data_format_, 'W')); + + // TF filter shape is [ H, W, inC, outC ] + dims.add_kernel_spatial_dimensions(0); + dims.add_kernel_spatial_dimensions(1); + dims.set_kernel_input_feature_dimension(2); + dims.set_kernel_output_feature_dimension(3); + + std::vector window_strides = {stride_rows, stride_cols}; + xla::Padding xla_padding = + (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame; + + xla::ComputationDataHandle conv = ctx->builder()->ConvWithGeneralDimensions( + ctx->Input(0), ctx->Input(1), window_strides, xla_padding, dims); + ctx->SetOutput(0, conv); + } + + private: + std::vector strides_; + Padding padding_; + TensorFormat data_format_; + + TF_DISALLOW_COPY_AND_ASSIGN(Conv2DOp); +}; + +REGISTER_XLA_OP("Conv2D", Conv2DOp); + +// Backprop for input. 
+class Conv2DBackpropInputOp : public XlaOpKernel { + public: + explicit Conv2DBackpropInputOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + string data_format; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); + OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), + errors::InvalidArgument("Invalid data format")); + OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_)); + OP_REQUIRES(ctx, strides_.size() == 4, + errors::InvalidArgument("Sliding window strides field must " + "specify 4 dimensions")); + int stride_n = GetTensorDim(strides_, data_format_, 'N'); + int stride_c = GetTensorDim(strides_, data_format_, 'C'); + OP_REQUIRES( + ctx, (stride_n == 1 && stride_c == 1), + errors::InvalidArgument("Current implementation does not yet support " + "strides in the batch and depth dimensions.")); + OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape input_shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_shape)); + + const TensorShape filter_shape = ctx->InputShape(1); + const TensorShape out_backprop_shape = ctx->InputShape(2); + + // Reuse dimension computation logic from conv_grad_ops.cc. + Conv2DBackpropDimensions dims; + OP_REQUIRES_OK( + ctx, Conv2DBackpropComputeDimensions( + "Conv2DBackpropInput", input_shape, filter_shape, + out_backprop_shape, strides_, padding_, data_format_, &dims)); + + auto filter = ctx->Input(1); + auto out_backprop = ctx->Input(2); + + // The input gradients are computed by a convolution of the output + // gradients and the filter, with some appropriate padding. See the + // comment at the top of conv_grad_ops.h for details. + + xla::ConvolutionDimensionNumbers dnums; + dnums.set_batch_dimension(GetTensorDimIndex(data_format_, 'N')); + dnums.add_spatial_dimensions(GetTensorDimIndex(data_format_, 'H')); + dnums.add_spatial_dimensions(GetTensorDimIndex(data_format_, 'W')); + dnums.set_feature_dimension(GetTensorDimIndex(data_format_, 'C')); + + // TF filter shape is [ H, W, inC, outC ] + // Transpose the input and output features for computing the gradient. + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(3); + dnums.set_kernel_output_feature_dimension(2); + + // Mirror the filter in the spatial dimensions. 
+ xla::ComputationDataHandle mirrored_weights = + ctx->builder()->Rev(filter, {dnums.kernel_spatial_dimensions(0), + dnums.kernel_spatial_dimensions(1)}); + + // activation gradients + // = gradients (with padding and dilation) mirrored_weights + xla::ComputationDataHandle in_backprop = ctx->builder()->ConvGeneralDilated( + out_backprop, mirrored_weights, /*window_strides=*/{1, 1}, + /*padding=*/{{dims.rows.pad_before, dims.rows.pad_after}, + {dims.cols.pad_before, dims.cols.pad_after}}, + /*lhs_dilation=*/{dims.rows.stride, dims.cols.stride}, + /*rhs_dilation=*/{1, 1}, dnums); + + ctx->SetOutput(0, in_backprop); + } + + private: + std::vector strides_; + Padding padding_; + TensorFormat data_format_; + + TF_DISALLOW_COPY_AND_ASSIGN(Conv2DBackpropInputOp); +}; + +class Conv2DBackpropFilterOp : public XlaOpKernel { + public: + explicit Conv2DBackpropFilterOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + string data_format; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); + OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), + errors::InvalidArgument("Invalid data format")); + OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_)); + int stride_n = GetTensorDim(strides_, data_format_, 'N'); + int stride_c = GetTensorDim(strides_, data_format_, 'C'); + OP_REQUIRES( + ctx, (stride_n == 1 && stride_c == 1), + errors::InvalidArgument("Current implementation does not yet support " + "strides in the batch and depth dimensions.")); + OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape activations_shape = ctx->InputShape(0); + TensorShape filter_shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(1, &filter_shape)); + const TensorShape out_backprop_shape = ctx->InputShape(2); + + // Reuse dimension computation logic from conv_grad_ops.cc. + Conv2DBackpropDimensions dims; + OP_REQUIRES_OK( + ctx, Conv2DBackpropComputeDimensions( + "Conv2DBackpropFilter", activations_shape, filter_shape, + out_backprop_shape, strides_, padding_, data_format_, &dims)); + + xla::ComputationDataHandle activations = ctx->Input(0); + xla::ComputationDataHandle gradients = ctx->Input(2); + + // The filter gradients are computed by a convolution of the input + // activations and the output gradients, with some appropriate padding. + // See the comment at the top of conv_grad_ops.h for details. + + xla::ConvolutionDimensionNumbers dnums; + + // The activations (inputs) form the LHS of the convolution. + // Activations have shape: [batch, in_rows, in_cols, in_depth] + // For the gradient computation, we flip the roles of the batch and + // feature dimensions. + // Each spatial entry has size in_depth * batch + const int n_dim = GetTensorDimIndex(data_format_, 'N'); + const int h_dim = GetTensorDimIndex(data_format_, 'H'); + const int w_dim = GetTensorDimIndex(data_format_, 'W'); + const int c_dim = GetTensorDimIndex(data_format_, 'C'); + + // Swap n_dim and c_dim in the activations. + dnums.set_batch_dimension(c_dim); + dnums.add_spatial_dimensions(h_dim); + dnums.add_spatial_dimensions(w_dim); + dnums.set_feature_dimension(n_dim); + + // The gradients become the RHS of the convolution. + // The gradients have shape [batch, out_rows, out_cols, out_depth] where + // the batch becomes the input feature for the convolution. 
+ dnums.add_kernel_spatial_dimensions(h_dim); + dnums.add_kernel_spatial_dimensions(w_dim); + dnums.set_kernel_input_feature_dimension(n_dim); + dnums.set_kernel_output_feature_dimension(c_dim); + + // We will also need to pad the input with zeros such that after the + // convolution, we get the right size for the filter. + // The padded_in_rows should be such that when we convolve this with the + // expanded_out_rows as a filter, we should get filter_rows back. + // + const int padded_in_rows = + dims.rows.expanded_output_size + dims.rows.filter_size - 1; + const int padded_in_cols = + dims.cols.expanded_output_size + dims.cols.filter_size - 1; + + // However it can be smaller than input_rows: in this + // case it means some of the inputs are not used. + // + // An example is to have input_cols = 3, filter_cols = 2 and stride = 2: + // + // INPUT = [ A B C ] + // + // FILTER = [ x y ] + // + // and the output will only have one column: a = A * x + B * y + // + // and input "C" is not used at all. + // + // We apply negative padding in this case. + const int total_pad_in_rows = padded_in_rows - dims.rows.input_size; + const int total_pad_in_cols = padded_in_cols - dims.cols.input_size; + + // + For the VALID padding, we don't pad anything on the top/left side + // and pad the bottom/right side with the remaining space. + // + For the SAME padding, we pad top/left side the same as bottom/right + // side. + // + // In addition, if the padded input size is smaller than the input size, + // we need to ignore some training elements of the input. We do this by + // applying negative padding on the right/bottom. + const int top_pad_in_rows = + (total_pad_in_rows > 0 && padding_ == Padding::SAME) + ? total_pad_in_rows / 2 + : 0; + const int left_pad_in_cols = + (total_pad_in_cols > 0 && padding_ == Padding::SAME) + ? total_pad_in_cols / 2 + : 0; + + // Besides padding the input, we will also expand output_rows to + // expanded_out_rows = (output_rows - 1) * stride + 1 + // with zeros in between: + // + // a . . . b . . . c . . . d . . . e + // + // This is done by specifying the window dilation factors in the + // convolution HLO below. + auto filter_backprop = ctx->builder()->ConvGeneralDilated( + activations, gradients, + /*window_strides=*/{1, 1}, + /*padding=*/{{top_pad_in_rows, total_pad_in_rows - top_pad_in_rows}, + {left_pad_in_cols, total_pad_in_cols - left_pad_in_cols}}, + /*lhs_dilation=*/{1, 1}, + /*rhs_dilation=*/{dims.rows.stride, dims.cols.stride}, dnums); + + // The layout of filter_backprop will match the layout of + // padded_activations + // and so will have layout: [out_feature, h, w, in_feature] + // Tensorflow filter shape is [ H, W, inC, outC ], so we transpose the + // output. + xla::ComputationDataHandle filter_backprop_reshaped = + ctx->builder()->Transpose(filter_backprop, + {h_dim, w_dim, c_dim, n_dim}); + ctx->SetOutput(0, filter_backprop_reshaped); + } + + private: + std::vector strides_; + Padding padding_; + TensorFormat data_format_; + + TF_DISALLOW_COPY_AND_ASSIGN(Conv2DBackpropFilterOp); +}; + +REGISTER_XLA_OP("Conv2DBackpropInput", Conv2DBackpropInputOp); +REGISTER_XLA_OP("Conv2DBackpropFilter", Conv2DBackpropFilterOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc new file mode 100644 index 0000000000..3cd0b39c87 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc @@ -0,0 +1,177 @@ +/* Copyright 2017 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA-specific base classes for Unary and Binary Ops. + +#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" + +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/util/bcast.h" + +namespace tensorflow { + +void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) { + const TensorShape lhs_shape = ctx->InputShape(0); + const TensorShape rhs_shape = ctx->InputShape(1); + + // By TensorFlow conventions the inputs may not have the same + // shapes, in which case they will be automatically broadcast if + // possible before mapping. Use the standard TensorFlow helper to + // compute valid broadcast shapes, but rely below on XLA to + // automatically perform the broadcast assuming its valid shapes are + // a superset of TensorFlow's valid shapes. + BCast bcast(BCast::FromShape(lhs_shape), BCast::FromShape(rhs_shape)); + if (!bcast.IsValid()) { + ctx->SetStatus(errors::InvalidArgument("Incompatible shapes: ", + lhs_shape.DebugString(), " vs. ", + rhs_shape.DebugString())); + return; + } + TensorShape bcast_shape = BCast::ToShape(bcast.output_shape()); + + // Fetch the expressions containing the input tensors. + auto lhs_handle = ctx->Input(0); + auto rhs_handle = ctx->Input(1); + + // If the ranks of the inputs don't match, TensorFlow automatically + // reshapes the smaller by padding with dimensions of size 1 as a + // prefix. In other words to pad a 5-vector to a 3-dimensional + // tensor it is reshaped to have shape [1,1,5]. XLA's automatic + // broadcast code is able to broadcast from lower to higher rank, + // but doesn't assume you want to pad as a prefix of the dimensions, + // and instead needs to be told which dimensions of the higher rank + // tensor to match to the lower rank tensor. In this example it + // would be dimensions [2]. If we were matching a matrix against a + // 4-D tensor the dimensions to match would be [2,3], + // etc. extend_dimension encodes the general case. + std::vector extend_dimension; + int max_rank = std::max(lhs_shape.dims(), rhs_shape.dims()); + int min_rank = std::min(lhs_shape.dims(), rhs_shape.dims()); + if (min_rank != max_rank) { + for (int i = 0; i < min_rank; ++i) { + // Match the lower rank tensor along the larger-numbered + // dimensions of the higher rank tensor. + extend_dimension.push_back(max_rank - min_rank + i); + } + } + + // Call virtual method to emit the computation. 
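Before the Computation() call that follows, a minimal sketch of the extend_dimension bookkeeping computed just above: when the operand ranks differ, the lower-rank operand is matched against the trailing dimensions of the higher-rank one (the helper name is illustrative):

#include <cstdio>
#include <vector>

// Dimensions of the higher-rank operand that the lower-rank operand maps to,
// assuming TensorFlow-style broadcasting (missing dimensions are leading ones).
std::vector<int> ExtendDimensions(int lhs_rank, int rhs_rank) {
  const int max_rank = lhs_rank > rhs_rank ? lhs_rank : rhs_rank;
  const int min_rank = lhs_rank < rhs_rank ? lhs_rank : rhs_rank;
  std::vector<int> dims;
  if (min_rank != max_rank) {
    for (int i = 0; i < min_rank; ++i) dims.push_back(max_rank - min_rank + i);
  }
  return dims;
}

int main() {
  // 5-vector vs. 3-D tensor -> {2}; matrix vs. 4-D tensor -> {2, 3}.
  for (int d : ExtendDimensions(3, 1)) std::printf("%d ", d);
  std::printf("\n");
  for (int d : ExtendDimensions(2, 4)) std::printf("%d ", d);
  std::printf("\n");
  return 0;
}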
+ xla::ComputationDataHandle output = + Computation(ctx, lhs_handle, lhs_shape.dim_sizes(), rhs_handle, + rhs_shape.dim_sizes(), bcast, extend_dimension); + + // The TensorFlow helper computed the post-broadcast shape in + // output_shape: we rely on subclassed Computations to implement the + // same broadcast semantics. + ctx->SetOutput(0, output); +} + +/* static */ std::pair +XlaBinaryOp::Broadcast(xla::ComputationBuilder* builder, + const xla::ComputationDataHandle& lhs, + const xla::ComputationDataHandle& rhs, + const BCast& broadcast_helper) { + // Manually construct the broadcasting since MapN does not do + // automatic broadcasting. The bcast helper ensures that + // lhs.reshape(bcast.x_reshape()).broadcast(bcast.x_bcast()) and + // rhs.reshape(bcast.y_reshape()).broadcast(bcast.y_bcast()) have + // the same shape, so can be operated on by MapN. + + // First reshape the inputs, which should be a metadata-only + // operation since we are flattening the dimensions in order. + auto lhs_shaped = builder->Reshape(lhs, broadcast_helper.x_reshape()); + auto rhs_shaped = builder->Reshape(rhs, broadcast_helper.y_reshape()); + + // Next broadcast the necessary input dimensions. We rely on the + // XLA optimizer to be smart about the fact that we are asking + // it to broadcast size 1 on some of these dimensions, to avoid + // adding complexity to this code. + auto lhs_broadcast = + builder->Broadcast(lhs_shaped, broadcast_helper.x_bcast()); + int lhs_size = broadcast_helper.x_bcast().size(); + auto rhs_broadcast = + builder->Broadcast(rhs_shaped, broadcast_helper.y_bcast()); + int rhs_size = broadcast_helper.y_bcast().size(); + + // Now reshape them to the correct output shape. After the + // broadcast each side is twice as wide as it should be, since the + // broadcast dimensions were prepended to the shape. Reshape + // flattening each original dimension with the prepended broadcast + // dimension. E.g. if we started out with lhs_shaped with shape + // [5,2,3] and x_bcast was [2,1,7] then lhs_broadcast would have + // shape [2,1,7,5,2,3] and we want to reshape it to [10,2,21]. + std::vector lhs_reorder; + for (int i = 0; i < lhs_size; ++i) { + lhs_reorder.push_back(i); + lhs_reorder.push_back(i + lhs_size); + } + auto lhs_output = builder->Reshape(lhs_broadcast, lhs_reorder, + broadcast_helper.output_shape()); + std::vector rhs_reorder; + for (int i = 0; i < rhs_size; ++i) { + rhs_reorder.push_back(i); + rhs_reorder.push_back(i + rhs_size); + } + auto rhs_output = builder->Reshape(rhs_broadcast, rhs_reorder, + broadcast_helper.output_shape()); + + return {lhs_output, rhs_output}; +} + +xla::ComputationDataHandle XlaBinaryMapOp::Computation( + XlaOpKernelContext* ctx, const xla::ComputationDataHandle& lhs, + const gtl::ArraySlice& lhs_shape, + const xla::ComputationDataHandle& rhs, + const gtl::ArraySlice& rhs_shape, const BCast& broadcast_helper, + const std::vector& extend_dimensions) { + xla::ComputationBuilder* builder = ctx->builder(); + + // Construct the builder for the lambda computation. + xla::ComputationBuilder l(builder->client(), ctx->op_kernel().name()); + xla::PrimitiveType type; + TF_CHECK_OK(DataTypeToPrimitiveType(input_type(0), &type)); + + // Make two scalar parameters of the desired type for the lambda. + xla::ComputationDataHandle x = + l.Parameter(0, xla::ShapeUtil::MakeShape(type, {}), "x"); + xla::ComputationDataHandle y = + l.Parameter(1, xla::ShapeUtil::MakeShape(type, {}), "y"); + + // Call virtual method to build the lambda. 
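A minimal sketch of the dimension bookkeeping performed by the Broadcast() helper above, using the [5,2,3] / [2,1,7] example from its comments: each prepended broadcast dimension is paired with the original dimension it scales, and each pair is collapsed by the final reshape (the struct and helper names are illustrative):

#include <cstdio>
#include <vector>

struct BroadcastPlan {
  std::vector<int> reorder;       // dimension order fed to the final reshape
  std::vector<long long> output;  // collapsed output shape
};

// `reshaped` is the operand shape after BCast's x_reshape()/y_reshape();
// `bcast` is the per-dimension broadcast factor (x_bcast()/y_bcast()).
BroadcastPlan PlanBroadcast(const std::vector<long long>& reshaped,
                            const std::vector<long long>& bcast) {
  const int rank = static_cast<int>(reshaped.size());
  BroadcastPlan plan;
  for (int i = 0; i < rank; ++i) {
    plan.reorder.push_back(i);         // the prepended broadcast dimension
    plan.reorder.push_back(i + rank);  // the original dimension it scales
    plan.output.push_back(bcast[i] * reshaped[i]);
  }
  return plan;
}

int main() {
  // Example from the comments: shape [5,2,3] broadcast by [2,1,7] -> [10,2,21].
  BroadcastPlan plan = PlanBroadcast({5, 2, 3}, {2, 1, 7});
  for (int d : plan.reorder) std::printf("%d ", d);         // 0 3 1 4 2 5
  std::printf("| ");
  for (long long d : plan.output) std::printf("%lld ", d);  // 10 2 21
  std::printf("\n");
  return 0;
}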
+ BuildMapLambda(&l, x, y); + xla::Computation computation = l.Build().ConsumeValueOrDie(); + + xla::ComputationDataHandle lhs_broadcast = lhs; + xla::ComputationDataHandle rhs_broadcast = rhs; + if (lhs_shape == rhs_shape) { + // There's no broadcasting to do. + CHECK_EQ(0, extend_dimensions.size()); + return builder->Map({lhs, rhs}, computation); + } else { + std::tie(lhs_broadcast, rhs_broadcast) = + Broadcast(builder, lhs, rhs, broadcast_helper); + } + // Now the two sides are broadcast to the final shape we can do the map. + return builder->Map({lhs_broadcast, rhs_broadcast}, computation); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h new file mode 100644 index 0000000000..f0687c1d4b --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h @@ -0,0 +1,109 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA-specific base classes for Unary and Binary Ops. + +#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_CWISE_OPS_H_ +#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_CWISE_OPS_H_ + +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/util/bcast.h" + +namespace tensorflow { + +// Coefficient-wise binary operations. Each binary Op expects two +// inputs that can be broadcast to the same shape. The base class +// contains pure virtual methods to override: description is a textual +// description of the operation; and Computation adds the +// implementation of the operation to a xla::ComputationBuilder. For most +// arithmetic Ops XLA handles the broadcasting automatically given the input +// tensors. Ops like ReluGrad that need to map a scalar function over the inputs +// can use the XlaBinaryMapOp subclass below which handles manual +// broadcasting of the inputs. +class XlaBinaryOp : public XlaOpKernel { + public: + explicit XlaBinaryOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + const DataType lhs = BaseType(input_type(0)); + const DataType rhs = BaseType(input_type(1)); + OP_REQUIRES(ctx, lhs == rhs, + errors::InvalidArgument("Input types of binary op must match")); + } + ~XlaBinaryOp() override {} + + // Implement the (tensor,tensor)->tensor lambda that should be + // applied to the inputs. The desired computation should be added to + // 'tc->builder()' and '(lhs,rhs)' are the function's inputs and + // (lhs_shape,rhs_shape) are their respective + // shapes. 'broadcast_helper' contains metadata about the shapes of + // the inputs and the dimensions that need to be broadcast, which + // may be useful for Ops that can't use standard XLA automatic + // broadcasting. 
'extend_dimension' is non-empty if lhs and rhs have + // different ranks, and indicates which dimensions of the + // higher-rank input should be matched when broadcasting the + // lower-rank input. See comment below and the documentation on broadcasting + // in the XLA documentation. + virtual xla::ComputationDataHandle Computation( + XlaOpKernelContext* ctx, const xla::ComputationDataHandle& lhs, + const gtl::ArraySlice& lhs_shape, + const xla::ComputationDataHandle& rhs, + const gtl::ArraySlice& rhs_shape, const BCast& broadcast_helper, + const std::vector& extend_dimensions) = 0; + + void Compile(XlaOpKernelContext* ctx) override; + + // Helper function that performs the broadcasting described by + // 'broadcast_helper', yielding arguments 'lhs' and 'rhs' that have the same + // shape. + static std::pair + Broadcast(xla::ComputationBuilder* builder, + const xla::ComputationDataHandle& lhs, + const xla::ComputationDataHandle& rhs, + const BCast& broadcast_helper); +}; + +// Coefficient-wise binary operations that map a scalar function. Each +// BinaryMap Op expects two inputs that can be broadcast to the same +// shape and maps a (scalar,scalar)->scalar function across the zipped +// elements of its (broadcast) inputs. The base class contains pure +// virtual methods to override: description is a textual description +// of the mapped function; and BuildMapLambda adds the +// implementation of the lambda to a xla::ComputationBuilder. +class XlaBinaryMapOp : public XlaBinaryOp { + public: + explicit XlaBinaryMapOp(OpKernelConstruction* ctx) : XlaBinaryOp(ctx) {} + ~XlaBinaryMapOp() override {} + + // Implement the (scalar,scalar)->scalar lambda that should be + // applied to each pair of elements of the inputs. The desired + // computation should be added to 'builder' and + // '(scalar_lhs,scalar_rhs)' are the function's inputs. + virtual void BuildMapLambda(xla::ComputationBuilder* builder, + const xla::ComputationDataHandle& scalar_lhs, + const xla::ComputationDataHandle& scalar_rhs) = 0; + + xla::ComputationDataHandle Computation( + XlaOpKernelContext* ctx, const xla::ComputationDataHandle& lhs, + const gtl::ArraySlice& lhs_shape, + const xla::ComputationDataHandle& rhs, + const gtl::ArraySlice& rhs_shape, const BCast& broadcast_helper, + const std::vector& extend_dimensions) override; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_CWISE_OPS_H_ diff --git a/tensorflow/compiler/tf2xla/kernels/declaration_op.cc b/tensorflow/compiler/tf2xla/kernels/declaration_op.cc new file mode 100644 index 0000000000..d96ff34178 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/declaration_op.cc @@ -0,0 +1,127 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/core/framework/kernel_def_builder.h" + +namespace tensorflow { +namespace { + +// This OpKernel implements the Constant Op for XLA JIT +// devices. It extracts the constant Tensor from the Proto at kernel +// construction time, and then every time the Constant Op is executed +// an expression containing the constant is compiled. +class ConstantDeclarationOp : public XlaOpKernel { + public: + explicit ConstantDeclarationOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx), tensor_(ctx->output_type(0)) { + const TensorProto* proto = nullptr; + OP_REQUIRES_OK(ctx, ctx->GetAttr("value", &proto)); + // MakeTensorFromProto uses the cpu_allocator, so tensor_ is a + // "real" tensor backed by CPU memory, holding the value of the + // constant. + OP_REQUIRES_OK(ctx, MakeTensorFromProto(*proto, &tensor_)); + OP_REQUIRES( + ctx, ctx->output_type(0) == tensor_.dtype(), + errors::InvalidArgument( + "Type mismatch between value (", DataTypeString(tensor_.dtype()), + ") and dtype (", DataTypeString(ctx->output_type(0)), ")")); + } + + void Compile(XlaOpKernelContext* ctx) override { + ctx->SetConstantOutput(0, tensor_); + } + + private: + // Extract the value of the constant from the Proto during Op kernel + // construction. The constant must be stored in a Tensor allocated + // using the cpu_allocator so that it is backed by real memory. The + // OpKernelConstruction's default allocator is the JITAllocator + // which only allocates enough space for metadata for each Tensor. + static Status MakeTensorFromProto(const TensorProto& tensor_proto, + Tensor* tensor) { + Tensor parsed(tensor_proto.dtype()); + if (!parsed.FromProto(cpu_allocator(), tensor_proto)) { + return errors::InvalidArgument("Cannot parse tensor from proto: ", + tensor_proto.DebugString()); + } + *tensor = parsed; + return Status::OK(); + } + + // This is a "real" tensor backed by CPU memory, containing the + // constant values. + Tensor tensor_; + TF_DISALLOW_COPY_AND_ASSIGN(ConstantDeclarationOp); +}; + +REGISTER_XLA_OP("Const", ConstantDeclarationOp); + +// This OpKernel implements the _Arg Op for XLA JIT devices. It +// associates its output with one of the arguments to a +// subcomputation. +class ArgOp : public XlaOpKernel { + public: + explicit ArgOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("index", &index_)); + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dtype_, &type_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + // If 'frame' is non-null, this is a function call inside an outer JIT + // compilation. Use the usual implementation of _Arg. + auto frame = ctx->call_frame(); + if (frame != nullptr) { + Tensor val; + OP_REQUIRES_OK(ctx, frame->GetArg(index_, &val)); + OP_REQUIRES(ctx, val.dtype() == dtype_, + errors::InvalidArgument( + "Type mismatch: actual ", DataTypeString(val.dtype()), + " vs. expect ", DataTypeString(dtype_))); + // Forwards the argument from the frame. 
+ ctx->op_kernel_context()->set_output(0, val); + return; + } + + XlaContext& tc = XlaContext::Get(ctx); + + OP_REQUIRES(ctx, 0 <= index_ && index_ < tc.args().size(), + errors::InvalidArgument("Invalid argument index ", index_)); + const XlaCompiler::Argument& arg = tc.args()[index_]; + + if (arg.parameter < 0) { + ctx->SetConstantOutput(0, arg.constant_value); + } else { + ctx->SetOutput(0, tc.parameter(arg.parameter)); + } + } + + private: + int index_; + DataType dtype_; + xla::PrimitiveType type_; // Corresponding XLA type. + + TF_DISALLOW_COPY_AND_ASSIGN(ArgOp); +}; + +REGISTER_XLA_OP("_Arg", ArgOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/depthwise_conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/depthwise_conv_ops.cc new file mode 100644 index 0000000000..d408ab3338 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/depthwise_conv_ops.cc @@ -0,0 +1,235 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA-specific Ops for 2D depthwise convolution. + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_slice.h" +#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/kernels/depthwise_conv_op.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/util/padding.h" +#include "tensorflow/core/util/tensor_format.h" + +namespace tensorflow { +namespace { + +// Name of the function to use as the implementation for depthwise 2D +// convolution. Default is empty string; another possible value is +// "DummyDepthwiseConv2dKernel". +static const char kDepthwiseConv2dCustomFunc[] = ""; + +class DepthwiseConv2dNativeOp : public XlaOpKernel { + public: + explicit DepthwiseConv2dNativeOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + // TODO(keveman): Refactor this (and other XLA OpKernel constructors) so + // that they use a common implementation shared with non-XLA kernels. 
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_)); + OP_REQUIRES(ctx, strides_.size() == 4, + errors::InvalidArgument("Sliding window strides field must " + "specify 4 dimensions")); + OP_REQUIRES(ctx, strides_[1] == strides_[2], + errors::InvalidArgument( + "Current implementation only supports equal length " + "strides in the row and column dimensions.")); + OP_REQUIRES( + ctx, (strides_[0] == 1 && strides_[3] == 1), + errors::InvalidArgument("Current implementation does not yet support " + "strides in the batch and depth dimensions.")); + OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + // Input tensor is of the following dimensions: + // [ batch, in_rows, in_cols, in_depth ] + const TensorShape input_shape = ctx->InputShape(0); + + // Input filter is of the following dimensions: + // [ filter_rows, filter_cols, in_depth, depth_multiplier] + const TensorShape filter_shape = ctx->InputShape(1); + + // For 2D convolution, there should be 4 dimensions. + OP_REQUIRES(ctx, input_shape.dims() == 4, + errors::InvalidArgument("input must be 4-dimensional", + input_shape.DebugString())); + OP_REQUIRES(ctx, filter_shape.dims() == 4, + errors::InvalidArgument("filter must be 4-dimensional: ", + filter_shape.DebugString())); + + // The last dimension for input is in_depth. It must be the same as the + // filter's in_depth. + const int64 in_depth = input_shape.dim_size(3); + OP_REQUIRES( + ctx, in_depth == filter_shape.dim_size(2), + errors::InvalidArgument("input and filter must have the same depth: ", + in_depth, " vs ", filter_shape.dim_size(2))); + + // The last dimension for filter is depth multiplier. + const int64 depth_multiplier = filter_shape.dim_size(3); + + // The output depth is input depth x depth multiplier. + const int64 out_depth = in_depth * depth_multiplier; + + // The second dimension for input is rows/height. + // The first dimension for filter is rows/height. + const int64 input_rows = input_shape.dim_size(1); + const int64 filter_rows = filter_shape.dim_size(0); + + // The third dimension for input is columns/width. + // The second dimension for filter is columns/width. + const int64 input_cols = input_shape.dim_size(2); + const int64 filter_cols = filter_shape.dim_size(1); + + // The first dimension for input is batch. + const int64 batch = input_shape.dim_size(0); + + // For now we take the stride from the second dimension only (we + // assume row = col stride, and do not support striding on the + // batch or depth dimension). 
+ const int32 stride = strides_[1]; + + int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0; + OP_REQUIRES_OK(ctx, GetWindowedOutputSize(input_rows, filter_rows, stride, + padding_, &out_rows, &pad_rows)); + OP_REQUIRES_OK(ctx, GetWindowedOutputSize(input_cols, filter_cols, stride, + padding_, &out_cols, &pad_cols)); + TensorShape out_shape({batch, out_rows, out_cols, out_depth}); + OP_REQUIRES( + ctx, out_shape.num_elements() <= 2147483647, + errors::InvalidArgument("total number of outputs should be within the " + "range of int which is used in the GPU kernel", + in_depth, " vs ", filter_shape.dim_size(2))); + + // Output tensor is of the following dimensions: + // [ in_batch, out_rows, out_cols, out_depth ] + + VLOG(2) << "DepthwiseConv2dNative: " + << " Input: [" << batch << ", " << input_rows << ", " << input_cols + << ", " << in_depth << "]; Filter: [" << filter_rows << ", " + << filter_cols << ", " << in_depth << ", " << depth_multiplier + << "]; stride = " << stride << ", pad_rows = " << pad_rows + << ", pad_cols = " << pad_cols << ", output: [" << batch << ", " + << out_rows << ", " << out_cols << ", " << out_depth << "]"; + + xla::ComputationBuilder& b = *ctx->builder(); + xla::ComputationDataHandle input = ctx->Input(0); + xla::ComputationDataHandle filter = ctx->Input(1); + xla::ComputationDataHandle output; + + const string custom_function_name = kDepthwiseConv2dCustomFunc; + if (!custom_function_name.empty()) { + xla::Shape xla_out_shape; + OP_REQUIRES_OK( + ctx, TensorShapeToXLAShape(input_type(0), out_shape, &xla_out_shape)); + + // The custom function for depthwise should interpret its arguments + // as follows : + // func(T* output, + // const T* input, const T* filter, + // const int32* input_size, const int32* filter_size, + // const int32* output_size, + // int32 stride, int32 pad_rows, int32 pad_cols) + // + // where T is the type of Tensor that this kernel is registered for. + // Note that the custom call op passes uses the following calling + // convention: + // func(void* output, void** inputs) + // + // Therefore the custom function should first construct the above + // inputs by unparsing the second argument passed to it. + output = b.CustomCall( + custom_function_name, + {input, filter, + b.ConstantR1({batch, input_rows, input_cols, in_depth}), + b.ConstantR1( + {filter_rows, filter_cols, in_depth, depth_multiplier}), + b.ConstantR1({batch, out_rows, out_cols, out_depth}), + b.ConstantR0(stride), b.ConstantR0(pad_rows), + b.ConstantR0(pad_cols)}, + xla_out_shape); + } else { + // These will be used to define the bounds of each slice. + // Within the loop, the input_channel index will be modified. + gtl::InlinedVector filter_begin; + gtl::InlinedVector filter_limits; + gtl::InlinedVector input_begin; + gtl::InlinedVector input_limits; + for (int i = 0; i < 4; ++i) { + filter_begin.push_back(0); + filter_limits.push_back(filter_shape.dim_size(i)); + input_begin.push_back(0); + input_limits.push_back(input_shape.dim_size(i)); + } + + std::vector strides_for_tla{strides_[1], strides_[2]}; + + xla::Padding xla_padding = + (padding_ == VALID) ? 
xla::Padding::kValid : xla::Padding::kSame; + + xla::ConvolutionDimensionNumbers dims; + dims.set_batch_dimension(0); + dims.set_feature_dimension(3); + dims.add_spatial_dimensions(1); + dims.add_spatial_dimensions(2); + + // TF filter shape is [ H, W, inC, outC ] + dims.add_kernel_spatial_dimensions(0); + dims.add_kernel_spatial_dimensions(1); + dims.set_kernel_input_feature_dimension(2); + dims.set_kernel_output_feature_dimension(3); + + // Create one convolution for each input channel + std::vector convs; + for (int i = 0; i < in_depth; ++i) { + filter_begin[2] = i; + filter_limits[2] = i + 1; + input_begin[3] = i; + input_limits[3] = i + 1; + + xla::ComputationDataHandle filter_slice = + b.Slice(filter, filter_begin, filter_limits); + xla::ComputationDataHandle input_slice = + b.Slice(input, input_begin, input_limits); + convs.push_back(b.ConvWithGeneralDimensions( + input_slice, filter_slice, strides_for_tla, xla_padding, dims)); + } + // Concatenate the per-channel convolutions along the depth dimension. + output = b.ConcatInDim(convs, 3); + } + + ctx->SetOutput(0, output); + } + + private: + std::vector strides_; + Padding padding_; + + TF_DISALLOW_COPY_AND_ASSIGN(DepthwiseConv2dNativeOp); +}; + +REGISTER_XLA_OP("DepthwiseConv2dNative", DepthwiseConv2dNativeOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc new file mode 100644 index 0000000000..b89109ff6a --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc @@ -0,0 +1,255 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class DiagOp : public XlaOpKernel { + public: + explicit DiagOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* builder = ctx->builder(); + + const TensorShape input_shape = ctx->InputShape(0); + + auto dims = input_shape.dim_sizes(); + OP_REQUIRES(ctx, !dims.empty(), + errors::InvalidArgument("Expected 1 <= dims, got shape ", + input_shape.DebugString())); + + xla::ComputationDataHandle diag = ctx->Input(0); + + // Picture: + // tf.diag([1, 2, 3, 4]) ==> [[1, 0, 0, 0] + // [0, 2, 0, 0] + // [0, 0, 3, 0] + // [0, 0, 0, 4]] + + // Flattens the input to 1D. + int64 size = input_shape.num_elements(); + diag = builder->Reshape(diag, {size}); + + // Adds inter-element padding of 'size'. + xla::PaddingConfig config; + auto* dim = config.add_dimensions(); + dim->set_interior_padding(size); + diag = builder->Pad(diag, XlaHelpers::Zero(builder, input_type(0)), config); + + // Reshapes to the final shape. 
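Before the final reshape below, it may help to see why interior padding works for Diag: padding `size` zeros between consecutive elements of the flattened input gives size + (size - 1) * size = size * size values, and a row-major reshape then places element i at position (i, i). A minimal host-side sketch of the [1, 2, 3, 4] example from the comment (the helper name is illustrative):

#include <cstdio>
#include <vector>

// Emulates Pad with interior padding of `n` zeros between consecutive
// elements, followed by a row-major reshape of the result to an n x n matrix.
std::vector<std::vector<int>> DiagViaInteriorPadding(
    const std::vector<int>& v) {
  const int n = static_cast<int>(v.size());
  std::vector<int> padded;
  for (int i = 0; i < n; ++i) {
    padded.push_back(v[i]);
    if (i + 1 < n) padded.insert(padded.end(), n, 0);  // n interior zeros
  }
  // padded.size() == n + (n - 1) * n == n * n; reshape row-major to [n, n].
  std::vector<std::vector<int>> matrix(n, std::vector<int>(n));
  for (int i = 0; i < n * n; ++i) matrix[i / n][i % n] = padded[i];
  return matrix;
}

int main() {
  for (const auto& row : DiagViaInteriorPadding({1, 2, 3, 4})) {
    for (int v : row) std::printf("%d ", v);
    std::printf("\n");  // 1 0 0 0 / 0 2 0 0 / 0 0 3 0 / 0 0 0 4
  }
  return 0;
}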
+ std::vector new_dims(dims.size() * 2); + std::copy(dims.begin(), dims.end(), new_dims.begin()); + std::copy(dims.begin(), dims.end(), new_dims.begin() + dims.size()); + diag = builder->Reshape(diag, new_dims); + + ctx->SetOutput(0, diag); + } +}; + +REGISTER_XLA_OP("Diag", DiagOp); + +class DiagPartOp : public XlaOpKernel { + public: + explicit DiagPartOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* builder = ctx->builder(); + + const TensorShape input_shape = ctx->InputShape(0); + auto dims = input_shape.dim_sizes(); + + int num_dims = dims.size(); + const int out_dims = num_dims / 2; + + OP_REQUIRES(ctx, 2 <= num_dims, + errors::InvalidArgument("Expected 2 <= dims, got shape ", + input_shape.DebugString())); + OP_REQUIRES(ctx, num_dims % 2 == 0, + errors::InvalidArgument("The input tensor must have even rank; " + "got shape ", + input_shape.DebugString())); + int64 new_size = 1; + std::vector new_dims; + for (int i = 0; i < out_dims; i++) { + OP_REQUIRES( + ctx, dims[i] == dims[i + out_dims], + errors::InvalidArgument("Invalid shape ", input_shape.DebugString(), + ": dimensions ", i, " and ", i + out_dims, + " do not match.")); + new_size *= dims[i]; + new_dims.push_back(dims[i]); + } + + xla::ComputationDataHandle diag = ctx->Input(0); + + // TODO(b/30878775): use Slice with strides when supported, in place of + // the Pad -> Reshape -> Slice. + + // Picture: + // [[1, 0, 0, 0] pad and reshape to [[1, 0, 0, 0, 0], + // [0, 2, 0, 0] =================> [2, 0, 0, 0, 0], + // [0, 0, 3, 0] [3, 0, 0, 0, 0], + // [0, 0, 0, 4]] [4, 0, 0, 0, 0]] + // and then slice out the first column. + + // Flattens the input to 1D. + int64 size = input_shape.num_elements(); + diag = builder->Reshape(diag, {size}); + + // Adds padding after the last element of 'new_size'. + xla::PaddingConfig config; + auto* dim = config.add_dimensions(); + dim->set_edge_padding_high(new_size); + auto zero = XlaHelpers::Zero(builder, input_type(0)); + diag = builder->Pad(diag, zero, config); + + // Reshapes so the diagonal is now in the first column. + diag = builder->Reshape(diag, {new_size, new_size + 1}); + + // Slices out the first column and reshapes to the final shape. + diag = builder->Slice(diag, {0, 0}, {new_size, 1}); + diag = builder->Reshape(diag, new_dims); + + ctx->SetOutput(0, diag); + } +}; + +REGISTER_XLA_OP("DiagPart", DiagPartOp); + +class MatrixDiagOp : public XlaOpKernel { + public: + explicit MatrixDiagOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* builder = ctx->builder(); + + const TensorShape input_shape = ctx->InputShape(0); + + auto dims = input_shape.dim_sizes(); + OP_REQUIRES(ctx, !dims.empty(), + errors::InvalidArgument("Expected 1 <= dims, got shape ", + input_shape.DebugString())); + + xla::ComputationDataHandle diag = ctx->Input(0); + + int last_dim = dims.size() - 1; + int64 last_dim_size = input_shape.dim_size(last_dim); + + // Adds inter-element padding of 'last_dim_size' to the last dimension. + xla::PaddingConfig config = xla::MakeNoPaddingConfig(dims.size()); + auto* dim = config.mutable_dimensions(last_dim); + dim->set_interior_padding(last_dim_size); + diag = builder->Pad(diag, XlaHelpers::Zero(builder, input_type(0)), config); + + // Reshapes to the final shape. 
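+ // For example, an input of shape [2, 3] has last_dim_size = 3; the interior
+ // padding expands each row to 3 + 2 * 3 = 9 elements, and the reshape to
+ // [2, 3, 3] places each row's values on the diagonal of its own 3x3 matrix.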
+ dims.push_back(last_dim_size); + diag = builder->Reshape(diag, dims); + + ctx->SetOutput(0, diag); + } +}; + +REGISTER_XLA_OP("MatrixDiag", MatrixDiagOp); + +class MatrixDiagPartOp : public XlaOpKernel { + public: + explicit MatrixDiagPartOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* builder = ctx->builder(); + + const TensorShape input_shape = ctx->InputShape(0); + auto dims = input_shape.dim_sizes(); + + OP_REQUIRES(ctx, 2 <= dims.size(), + errors::InvalidArgument("Expected 2 <= dims, got shape ", + input_shape.DebugString())); + + xla::ComputationDataHandle diag = ctx->Input(0); + + int last_dim = dims.size() - 1; + int64 last_dim_size = dims[last_dim]; + + // The smaller of the last two dimension sizes. + int64 smaller_dim_size = std::min(dims[last_dim - 1], dims[last_dim]); + + // TODO(b/30878775): use Slice with strides when supported, in place of + // the Pad -> Reshape -> Slice. + + // Picture: for each 2D matrix in the tensor's last two dimensions: + // [[1, 0, 0, 0] pad and reshape to [[1, 0, 0, 0, 0], + // [0, 2, 0, 0] =================> [2, 0, 0, 0, 0], + // [0, 0, 3, 0]] [3, 0, 0, 0, 0], + // and then slice out the first column. + // + // Another example, with tall and narrow input. + // [[1, 0] pad and reshape to [[1, 0, 0], + // [0, 2] =================> [2, 0, 0]] + // [0, 0] + // [0, 0]] + + // Collapses the last two dimensions. + std::vector flattened_dims(dims.begin(), dims.end() - 1); + flattened_dims.back() *= dims.back(); + diag = builder->Reshape(diag, flattened_dims); + + // Slices or pads the last dimension to 'target_size'. + int64 actual_size = flattened_dims.back(); + int64 target_size = smaller_dim_size * (last_dim_size + 1); + if (actual_size < target_size) { + xla::PaddingConfig config = + xla::MakeNoPaddingConfig(flattened_dims.size()); + auto* dim = config.mutable_dimensions(flattened_dims.size() - 1); + dim->set_edge_padding_high(target_size - actual_size); + auto zero = XlaHelpers::Zero(builder, input_type(0)); + diag = builder->Pad(diag, zero, config); + } else if (actual_size > target_size) { + std::vector start(flattened_dims.size(), 0); + std::vector limits(flattened_dims.begin(), flattened_dims.end()); + limits[flattened_dims.size() - 1] = target_size; + diag = builder->Slice(diag, start, limits); + } + + // Reshape so the target values are in the first position of the last + // dimension. + std::vector unflattened_dims(dims.begin(), dims.end()); + dims[last_dim - 1] = smaller_dim_size; + dims[last_dim] = last_dim_size + 1; + diag = builder->Reshape(diag, dims); + + // Slices out the first column and reshapes to the final shape. + std::vector start(dims.size(), 0); + std::vector limits(dims.begin(), dims.end()); + limits[last_dim] = 1; + diag = builder->Slice(diag, start, limits); + + // Collapses away the last dimension. + dims.pop_back(); + diag = builder->Reshape(diag, dims); + + ctx->SetOutput(0, diag); + } +}; + +REGISTER_XLA_OP("MatrixDiagPart", MatrixDiagPartOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc new file mode 100644 index 0000000000..2936e79261 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc @@ -0,0 +1,200 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA-specific dynamic stitch Op. + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/bounds_check.h" + +namespace tensorflow { +namespace { + +class DynamicStitchOp : public XlaOpKernel { + public: + explicit DynamicStitchOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES( + ctx, ctx->num_inputs() > 0, + errors::InvalidArgument("DynamicStitchOp: Must have some inputs")); + OP_REQUIRES(ctx, ctx->num_inputs() % 2 == 0, + errors::InvalidArgument( + "DynamicStitchOp: Must have even number of arguments")); + // Compute expected input signature + const int n = ctx->num_inputs() / 2; + const DataType dt = ctx->input_type(n); + DataTypeVector expected; + for (int i = 0; i < n; i++) { + expected.push_back(DT_INT32); + } + for (int i = 0; i < n; i++) { + expected.push_back(dt); + } + OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected, {dt})); + } + + void Compile(XlaOpKernelContext* ctx) override { + // Validate that data_shape[i] = indices[i].shape() + constant + std::vector indices_input; + OP_REQUIRES_OK(ctx, ctx->ConstantInputList("indices", &indices_input)); + + std::vector data; + std::vector data_shapes; + OP_REQUIRES_OK(ctx, ctx->InputList("data", &data, &data_shapes)); + + std::vector indices(indices_input.size()); + + const TensorShape& data0_shape = data_shapes[0]; + const TensorShape indices0_shape = + XLAShapeToTensorShape(indices_input[0].shape()); + for (int input_num = 0; input_num < indices_input.size(); input_num++) { + const TensorShape indices_shape = + XLAShapeToTensorShape(indices_input[input_num].shape()); + const TensorShape& data_shape = data_shapes[input_num]; + OP_REQUIRES(ctx, TensorShapeUtils::StartsWith(data_shape, indices_shape), + errors::InvalidArgument( + "data[", input_num, "].shape = ", + data_shape.DebugString(), " does not start with indices[", + input_num, "].shape = ", indices_shape.DebugString())); + OP_REQUIRES(ctx, + input_num == 0 || SameExtraShape(data0_shape, indices0_shape, + data_shape, indices_shape), + errors::InvalidArgument( + "Need data[0].shape[", indices0_shape.dims(), + ":] = data[", input_num, "].shape[", indices_shape.dims(), + ":], got data[0].shape = ", data0_shape.DebugString(), + ", data[", input_num, "].shape = ", + data_shape.DebugString(), ", indices[0].shape = ", + indices0_shape.DebugString(), ", indices[", input_num, + "].shape = ", indices_shape.DebugString())); + + OP_REQUIRES_OK(ctx, + XlaHelpers::ReshapeLiteral(indices_input[input_num], + 
{indices_shape.num_elements()}, + &indices[input_num])); + } + + // Find which slice will be used for each index. If the same index + // appears in multiple inputs, the last one is used. The logic + // here is different from that in third_party/tensorflow because + // it is important for XLA that there be a well-formed Concat + // operation at the end. The existing CPU/GPU code copies multiple + // source slices to the same destination slice if there are + // repeated indices, whereas the XLA code works out which + // source slice will 'win' and only uses that in the Concat. + int max_index = -1; + for (int input_num = 0; input_num < indices.size(); input_num++) { + for (int i = 0; i < indices[input_num].shape().dimensions(0); ++i) { + max_index = std::max( + max_index, xla::LiteralUtil::Get(indices[input_num], {i})); + } + } + int number_of_indices = max_index + 1; + OP_REQUIRES(ctx, number_of_indices > 0, + errors::InvalidArgument("no indices supplied")); + // Construct the reverse mapping, for each index, of which slice of which + // input it comes from. + std::vector src_input_vector(number_of_indices); + std::vector src_slice_vector(number_of_indices); + std::vector src_index_used(number_of_indices); + int index_used_count = 0; + for (int input_num = 0; input_num < indices.size(); input_num++) { + for (int i = 0; i < indices[input_num].shape().dimensions(0); ++i) { + int index = xla::LiteralUtil::Get(indices[input_num], {i}); + src_input_vector[index] = input_num; + src_slice_vector[index] = i; + if (!src_index_used[index]) { + src_index_used[index] = true; + ++index_used_count; + } + } + } + OP_REQUIRES(ctx, index_used_count == number_of_indices, + errors::InvalidArgument("not all indices are used")); + + // Look up all the children expressions that represent the data + // inputs. + std::vector input(indices.size()); + for (int input_num = 0; input_num < indices.size(); input_num++) { + TensorShape new_shape; + // first reshaped dimension is the number of indices for this input. + new_shape.AddDim(indices[input_num].shape().dimensions(0)); + // Then the rest are the common extra shape. + for (int d = indices0_shape.dims(); d < data0_shape.dims(); d++) { + new_shape.AddDim(data0_shape.dim_size(d)); + } + // Get the data, shaped appropriately. + auto handle = data[input_num]; + if (new_shape == data_shapes[input_num]) { + input[input_num] = handle; + } else { + input[input_num] = + ctx->builder()->Reshape(handle, new_shape.dim_sizes()); + } + } + + // Set up the vectors for slicing: the first dimension will vary + // slice by slice, and the rest take the full common extra shape. + std::vector slice_start(1 + data0_shape.dims() - + indices0_shape.dims()); + std::vector slice_limit(1 + data0_shape.dims() - + indices0_shape.dims()); + for (int d = indices0_shape.dims(); d < data0_shape.dims(); d++) { + slice_limit[1 + d - indices0_shape.dims()] = data0_shape.dim_size(d); + } + std::vector to_concat(number_of_indices); + for (int index_num = 0; index_num < number_of_indices; index_num++) { + const auto& expression = input[src_input_vector[index_num]]; + // Take the appropriate slice of data. + slice_start[0] = src_slice_vector[index_num]; + slice_limit[0] = src_slice_vector[index_num] + 1; + // And place it in the concat list in the place indicated by + // the index. 
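+ // For example, with indices = {{0, 2}, {1}} the reverse mapping is
+ // src_input_vector = {0, 1, 0} and src_slice_vector = {0, 0, 1}, so the
+ // concat list is {data[0] slice 0, data[1] slice 0, data[0] slice 1}.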
+ to_concat[index_num] = + ctx->builder()->Slice(expression, slice_start, slice_limit); + } + + ctx->SetOutput(0, ctx->builder()->ConcatInDim(to_concat, 0)); + } + + private: + // Check if data0_shape[indices0.dims():] == data1_shape[indices1.dims():] + static bool SameExtraShape(const TensorShape& data0_shape, + const TensorShape& indices0, + const TensorShape& data1_shape, + const TensorShape& indices1) { + const int extra0 = data0_shape.dims() - indices0.dims(); + const int extra1 = data1_shape.dims() - indices1.dims(); + if (extra0 != extra1) return false; + for (int i = 0; i < extra0; i++) { + if (data0_shape.dim_size(indices0.dims() + i) != + data1_shape.dim_size(indices1.dims() + i)) { + return false; + } + } + return true; + } +}; + +REGISTER_XLA_OP("DynamicStitch", DynamicStitchOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/fill_op.cc b/tensorflow/compiler/tf2xla/kernels/fill_op.cc new file mode 100644 index 0000000000..918c80aad8 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/fill_op.cc @@ -0,0 +1,74 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA-specific Fill Op. + +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/register_types.h" + +namespace tensorflow { +namespace { + +class FillOp : public XlaOpKernel { + public: + explicit FillOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + // The output of this Op is a tensor of shape 'dims_shape' with each + // element set to the scalar 'dims_literal'. + const TensorShape dims_shape = ctx->InputShape(0); + const TensorShape value_shape = ctx->InputShape(1); + OP_REQUIRES( + ctx, IsLegacyVector(dims_shape), + errors::InvalidArgument("dims must be a vector of int32, got shape ", + dims_shape.DebugString())); + OP_REQUIRES(ctx, IsLegacyScalar(value_shape), + errors::InvalidArgument("value must be a scalar, got shape ", + value_shape.DebugString())); + // Evaluate the 'dims' constant input, reshaping to a vector if it + // was a 'legacy' vector (secretly a scalar). + xla::Literal dims_literal; + OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped( + 0, {dims_shape.num_elements()}, &dims_literal)); + + // Convert the dims literal into a vector that we can pass to + // ComputationBuilder. + std::vector broadcast; + for (int i = 0; i < dims_literal.shape().dimensions(0); ++i) { + broadcast.push_back(xla::LiteralUtil::Get(dims_literal, {i})); + } + // Look up the value input, reshaping to a scalar if it was a + // 'legacy' scalar (secretly a vector). 
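+ // For example, Fill(dims = [2, 3], value = 9) produces broadcast = {2, 3},
+ // and the Broadcast below expands the scalar 9 into a 2x3 tensor of nines.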
+ xla::ComputationDataHandle data = ctx->Input(1); + if (value_shape.dims() > 0) { + CHECK_EQ(value_shape.dims(), 1); + data = ctx->builder()->Reshape(data, {}); + } + // Emit the actual computation, which broadcasts the scalar to the + // desired shape. + auto result = ctx->builder()->Broadcast(data, broadcast); + + ctx->SetOutput(0, result); + } +}; + +REGISTER_XLA_OP("Fill", FillOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/function_ops.cc b/tensorflow/compiler/tf2xla/kernels/function_ops.cc new file mode 100644 index 0000000000..53f2196dc5 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/function_ops.cc @@ -0,0 +1,110 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/core/framework/kernel_def_builder.h" + +namespace tensorflow { +namespace { + +const char* const kGradientOp = "SymbolicGradient"; + +// Implementations of _ListToArray and _ArrayToList for functions. +class PassOn : public XlaOpKernel { + public: + explicit PassOn(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES(ctx, ctx->num_inputs() == ctx->num_outputs(), + errors::Internal("#inputs != #outputs : ", ctx->num_inputs(), + " vs. ", ctx->num_outputs())); + for (int i = 0; i < ctx->num_inputs(); ++i) { + OP_REQUIRES( + ctx, input_type(i) == output_type(i), + errors::Internal("Input and output types for position ", i, + " do not match: ", DataTypeString(input_type(i)), + " vs. ", DataTypeString(output_type(i)))); + } + } + + void Compile(XlaOpKernelContext* ctx) override { + for (int i = 0; i < ctx->num_inputs(); ++i) { + ctx->SetOutput(i, ctx->Input(i)); + } + } +}; + +REGISTER_XLA_OP("_ListToArray", PassOn); +REGISTER_XLA_OP("_ArrayToList", PassOn); + +// TODO(phawkins): this is an almost exact copy of the SymbolicGradientOp +// implementation from regular Tensorflow. Once XLA has been open sourced +// merge the two implementations. (Note: this implementation propagates the +// step_resource_manager). 
+class SymbolicGradientOp : public AsyncOpKernel { + public: + explicit SymbolicGradientOp(OpKernelConstruction* ctx) + : AsyncOpKernel(ctx), handle_(-1) {} + + ~SymbolicGradientOp() override {} + + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { + FunctionLibraryRuntime* lib = ctx->function_library(); + OP_REQUIRES_ASYNC(ctx, lib != nullptr, + errors::Internal("No function library is provided."), + done); + + OP_REQUIRES_OK_ASYNC( + ctx, lib->Instantiate(kGradientOp, def().attr(), &handle_), done); + + FunctionLibraryRuntime::Options opts; + opts.step_id = ctx->step_id(); + opts.runner = ctx->runner(); + opts.step_container = ctx->step_container(); + std::vector args; + args.reserve(ctx->num_inputs()); + for (int i = 0; i < ctx->num_inputs(); ++i) { + args.push_back(ctx->input(i)); + } + std::vector* rets = new std::vector; + lib->Run( + opts, handle_, args, rets, [ctx, done, rets](const Status& status) { + if (!status.ok()) { + ctx->SetStatus(status); + } else if (rets->size() != ctx->num_outputs()) { + ctx->SetStatus(errors::InvalidArgument( + "SymGrad expects to return ", ctx->num_outputs(), + " tensor(s), but get ", rets->size(), " tensor(s) instead.")); + } else { + for (size_t i = 0; i < rets->size(); ++i) { + ctx->set_output(i, (*rets)[i]); + } + } + delete rets; + done(); + }); + } + + private: + FunctionLibraryRuntime::Handle handle_; + + TF_DISALLOW_COPY_AND_ASSIGN(SymbolicGradientOp); +}; + +REGISTER_XLA_OP(kGradientOp, SymbolicGradientOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc new file mode 100644 index 0000000000..b98d386479 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -0,0 +1,104 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/core/framework/kernel_def_builder.h" + +namespace tensorflow { +namespace { + +class GatherOp : public XlaOpKernel { + public: + explicit GatherOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape params_shape = ctx->InputShape(0); + const TensorShape indices_shape = ctx->InputShape(1); + OP_REQUIRES( + ctx, TensorShapeUtils::IsVectorOrHigher(params_shape), + errors::InvalidArgument("params must be at least 1 dimensional")); + + DataType index_type = input_type(1); + OP_REQUIRES(ctx, index_type == DT_INT32 || index_type == DT_INT64, + errors::InvalidArgument("index must be int32 or int64")); + + // Check that we have enough index space. + const int64 limit = index_type == DT_INT32 + ? 
std::numeric_limits::max() + : std::numeric_limits::max(); + OP_REQUIRES( + ctx, params_shape.dim_size(0) <= limit, + errors::InvalidArgument("params.shape[0] too large for ", + DataTypeString(index_type), " indexing: ", + params_shape.dim_size(0), " > ", limit)); + + // The result shape is indices.shape + params.shape[1:]. + TensorShape result_shape = indices_shape; + for (int i = 1; i < params_shape.dims(); i++) { + result_shape.AddDim(params_shape.dim_size(i)); + } + + XlaContext& tc = XlaContext::Get(ctx); + OP_REQUIRES( + ctx, tc.allow_cpu_custom_calls(), + errors::InvalidArgument("Gather op requires CustomCall on CPU")); + + xla::ComputationBuilder& b = *ctx->builder(); + + // Call gather_xla_float_kernel (from gather_op_kernel_float.cc). + // XLA passes to the function, so it is not included here. + std::vector args; + args.push_back(tc.GetOrCreateRuntimeContextParameter()); + args.push_back(b.ConstantLiteral( + *xla::LiteralUtil::CreateR0(indices_shape.num_elements()))); + args.push_back(b.ConstantLiteral( + *xla::LiteralUtil::CreateR0(params_shape.dim_size(0)))); + args.push_back(b.ConstantLiteral(*xla::LiteralUtil::CreateR0( + params_shape.num_elements() / params_shape.dim_size(0)))); + args.push_back(ctx->Input(0)); + args.push_back(ctx->Input(1)); + + xla::Shape xla_out_shape; + OP_REQUIRES_OK( + ctx, TensorShapeToXLAShape(DT_FLOAT, result_shape, &xla_out_shape)); + + // Call the custom code with args: + xla::ComputationDataHandle output; + if (index_type == DT_INT32) { + output = b.CustomCall("gather_float_int32_xla_impl", args, xla_out_shape); + } else { + output = b.CustomCall("gather_float_int64_xla_impl", args, xla_out_shape); + } + + ctx->SetOutput(0, output); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(GatherOp); +}; + +REGISTER_XLA_OP("Gather", GatherOp); + +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Gather") + .TypeConstraint("Tparams", DT_FLOAT) + .TypeConstraint("Tindices", {DT_INT32, DT_INT64})); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int32.cc b/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int32.cc new file mode 100644 index 0000000000..eff23bd77d --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int32.cc @@ -0,0 +1,69 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/compiler/tf2xla/xla_local_runtime_context.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/gather_functor.h" +#include "tensorflow/core/platform/dynamic_annotations.h" + +namespace tensorflow { + +EIGEN_STRONG_INLINE void gather_float_int32_xla_impl(float* out, void** data) { + // data is managed by the JIT code so msan can't tell it's initialized. 
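+ // Layout of 'data', as set up by the custom call in gather_op.cc:
+ //   data[0] = XlaLocalRuntimeContext*  data[1] = number of indices
+ //   data[2] = params.shape[0]          data[3] = product of params.shape[1:]
+ //   data[4] = the params floats        data[5] = the int32 indices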
+ TF_ANNOTATE_MEMORY_IS_INITIALIZED(data, 6 * sizeof(void*)); + + int64 indices_size = *static_cast(data[1]); + int64 params_x = *static_cast(data[2]); + int64 params_y = *static_cast(data[3]); + + float* in = static_cast(data[4]); + + int32* indices = static_cast(data[5]); + Eigen::DSizes in_eig_sizes; + in_eig_sizes[0] = params_x; + in_eig_sizes[1] = params_y; + tensorflow::TTypes::ConstMatrix in_eig(in, in_eig_sizes); + + Eigen::DSizes indices_eig_sizes; + indices_eig_sizes[0] = indices_size; + tensorflow::TTypes::ConstFlat indices_eig(indices, indices_eig_sizes); + + Eigen::DSizes out_eig_sizes; + out_eig_sizes[0] = indices_size; + out_eig_sizes[1] = params_y; + tensorflow::TTypes::Matrix out_eig(out, out_eig_sizes); + + tensorflow::functor::GatherFunctorCPU f; + const int64 bad_i = f(in_eig, indices_eig, out_eig); + if (bad_i != -1) { + tensorflow::XlaLocalRuntimeContext* runtime_context = + static_cast(data[0]); + runtime_context->error = true; + runtime_context->error_msg = "Invalid index for gather"; + for (int i = 0; i < out_eig.size(); ++i) out[i] = 0; + } +} + +} // namespace tensorflow + +// Implements gather on CPU. This is called by an XLA custom call, set up by +// gather_op.cc. +extern "C" void __attribute__((visibility("default"))) +gather_float_int32_xla_impl(float* out, void** data) { + tensorflow::gather_float_int32_xla_impl(out, data); +} diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int64.cc b/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int64.cc new file mode 100644 index 0000000000..ae31f6f200 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int64.cc @@ -0,0 +1,69 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/compiler/tf2xla/xla_local_runtime_context.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/gather_functor.h" +#include "tensorflow/core/platform/dynamic_annotations.h" + +namespace tensorflow { + +EIGEN_STRONG_INLINE void gather_float_int64_xla_impl(float* out, void** data) { + // data is managed by the JIT code so msan can't tell it's initialized. 
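+ // Same layout as the int32 variant: data[0] = runtime context, data[1] =
+ // number of indices, data[2] = params.shape[0], data[3] = product of
+ // params.shape[1:], data[4] = the params floats, data[5] = the int64 indices.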
+ TF_ANNOTATE_MEMORY_IS_INITIALIZED(data, 6 * sizeof(void*)); + + int64 indices_size = *static_cast(data[1]); + int64 params_x = *static_cast(data[2]); + int64 params_y = *static_cast(data[3]); + + float* in = static_cast(data[4]); + + int64* indices = static_cast(data[5]); + Eigen::DSizes in_eig_sizes; + in_eig_sizes[0] = params_x; + in_eig_sizes[1] = params_y; + tensorflow::TTypes::ConstMatrix in_eig(in, in_eig_sizes); + + Eigen::DSizes indices_eig_sizes; + indices_eig_sizes[0] = indices_size; + tensorflow::TTypes::ConstFlat indices_eig(indices, indices_eig_sizes); + + Eigen::DSizes out_eig_sizes; + out_eig_sizes[0] = indices_size; + out_eig_sizes[1] = params_y; + tensorflow::TTypes::Matrix out_eig(out, out_eig_sizes); + + tensorflow::functor::GatherFunctorCPU f; + const int64 bad_i = f(in_eig, indices_eig, out_eig); + if (bad_i != -1) { + tensorflow::XlaLocalRuntimeContext* runtime_context = + static_cast(data[0]); + runtime_context->error = true; + runtime_context->error_msg = "Invalid index for gather"; + for (int i = 0; i < out_eig.size(); ++i) out[i] = 0; + } +} + +} // namespace tensorflow + +// Implements gather on CPU. This is called by an XLA custom call, set up by +// gather_op.cc. +extern "C" void __attribute__((visibility("default"))) +gather_float_int64_xla_impl(float* out, void** data) { + tensorflow::gather_float_int64_xla_impl(out, data); +} diff --git a/tensorflow/compiler/tf2xla/kernels/identity_op.cc b/tensorflow/compiler/tf2xla/kernels/identity_op.cc new file mode 100644 index 0000000000..01417a3cdf --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/identity_op.cc @@ -0,0 +1,39 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" + +namespace tensorflow { +namespace { + +class IdentityOp : public XlaOpKernel { + public: + explicit IdentityOp(OpKernelConstruction* context) : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* ctx) override { + ctx->SetOutput(0, ctx->Input(0)); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(IdentityOp); +}; + +REGISTER_XLA_OP("Identity", IdentityOp); +REGISTER_XLA_OP("PreventGradient", IdentityOp); +REGISTER_XLA_OP("StopGradient", IdentityOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops.cc b/tensorflow/compiler/tf2xla/kernels/index_ops.cc new file mode 100644 index 0000000000..293705e39f --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/index_ops.cc @@ -0,0 +1,142 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Native XLA implementations of indexing ops. + +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/kernels/bounds_check.h" + +namespace tensorflow { +namespace { + +// The logic below uses a custom-call to implement argmax. +// +// TODO(toddw): We can implement argmax using existing XLA ops. The idea is +// to use SelectAndScatter to create a tensor initialized to 0, where the max +// value along dim is set to 1. Then take the dot-product of that against a +// vector of indices [0,dim_size), which yields the result. As a detail, we +// might need to reshape before and afterwards, since the XLA Dot operator +// only performs the sum of products over dimension 0. +// +// rs = Reshape(input, ...) // reshape so dim is inner-most +// one_max = SelectAndScatter(rs, greater_than, +// {1,1,...,dim_size}, {1,1,...,dim_size}, +// VALID, [1], 0, add) +// indices = [0,1,2,...,dim_size-1] +// max_index = Dot(one_max, indices) +// result = Reshape(max_index, ...) // reshape back to original +// +// Also see b/29507024 for first-class XLA support for indexing ops. + +class ArgMaxOp : public XlaOpKernel { + public: + explicit ArgMaxOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape input_shape = ctx->InputShape(0); + const TensorShape dimension_shape = ctx->InputShape(1); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(dimension_shape), + errors::InvalidArgument( + "dim must be a scalar, but received tensor of shape: ", + dimension_shape.DebugString())); + + // We require that the dimension argument is a constant, since it lets us + // dispatch to a specialized custom-call function without any run-time + // overhead, when compiling ahead-of-time. + // + // TODO(toddw): We could remove this requirement if necessary; we'd also + // need to update const_analysis.cc. However it seems likely that a native + // XLA op would have the same requirement. + xla::Literal literal; + OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &literal)); + const int32 dim = xla::LiteralUtil::Get(literal, {}); + OP_REQUIRES(ctx, dim >= 0, errors::InvalidArgument("dim must be >= 0")); + OP_REQUIRES( + ctx, dim < input_shape.dims(), + errors::InvalidArgument("dim must be < input rank (", + input_shape.dims(), "), but got: ", dim)); + const int64 dim_size = input_shape.dim_size(dim); + OP_REQUIRES( + ctx, dim_size > 0, + errors::InvalidArgument("Reduction axis ", dim, " is empty in shape: ", + input_shape.DebugString())); + + // The output shape is the input shape contracted along dim. 
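+ // For example, an input of shape [3, 7] with dim = 1 yields an output of
+ // shape [3]; with dim = 0 it yields [7].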
+ TensorShape output_shape; + for (int d = 0; d < input_shape.dims() - 1; ++d) { + output_shape.AddDim(input_shape.dim_size((d < dim) ? d : d + 1)); + } + + // For now we use a custom-call, only for the 1d and 2d cases. + OP_REQUIRES(ctx, XlaContext::Get(ctx).allow_cpu_custom_calls(), + errors::InvalidArgument( + "ArgMax implementation requires a CustomCall on CPU")); + xla::ComputationBuilder& b = *ctx->builder(); + + // XLA passes to the function, so it is not included here. + std::vector args; + args.push_back(ctx->Input(0)); + args.push_back(b.ConstantLiteral( + *xla::LiteralUtil::CreateR1(input_shape.dim_sizes()))); + if (input_shape.dims() > 1) { + // Don't bother passing the output shape and dim for the 1d case, since + // the shape is always a scalar and the dim is always 0. + args.push_back(b.ConstantLiteral( + *xla::LiteralUtil::CreateR1(output_shape.dim_sizes()))); + args.push_back( + b.ConstantLiteral(*xla::LiteralUtil::CreateR0(dim))); + } + + xla::Shape xla_shape = + xla::ShapeUtil::MakeShape(xla::S64, output_shape.dim_sizes()); + + // Tell XLA to call the custom code, defined in + // index_ops_kernel_argmax_float_1d.cc. + xla::ComputationDataHandle output; + switch (input_shape.dims()) { + case 1: + output = b.CustomCall("argmax_float_1d_xla_impl", args, xla_shape); + break; + case 2: + output = b.CustomCall("argmax_float_2d_xla_impl", args, xla_shape); + break; + default: + OP_REQUIRES(ctx, false, + errors::Unimplemented( + "Argmax is only implemented for 1d and 2d tensors" + ", but got shape: ", + input_shape.DebugString())); + } + ctx->SetOutput(0, output); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(ArgMaxOp); +}; + +REGISTER_XLA_OP("ArgMax", ArgMaxOp); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("ArgMax").TypeConstraint("T", DT_FLOAT)); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc new file mode 100644 index 0000000000..0033a949a3 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc @@ -0,0 +1,49 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +EIGEN_STRONG_INLINE void argmax_float_1d_xla_impl(void* out, void** data) { + // data is managed by the JIT code so msan can't tell it's initialized. 
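+ // Layout of 'data', as set up by the custom call in index_ops.cc for the 1d
+ // case: data[0] = the input floats, data[1] = the input length as an int64.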
+ TF_ANNOTATE_MEMORY_IS_INITIALIZED(data, 2 * sizeof(void*)); + + float* input = static_cast(data[0]); + int64 input_size = *static_cast(data[1]); + + Eigen::DSizes in_eig_sizes(input_size); + TTypes::ConstTensor in_eig(input, in_eig_sizes); + + Eigen::DSizes out_eig_sizes; + int64* out_t = static_cast(out); + TTypes::Tensor out_eig(out_t, out_eig_sizes); + + out_eig = in_eig.argmax(0).cast(); +} + +} // namespace tensorflow + +// Implements argmax on CPU. This is called by an XLA custom call, set up by +// index_ops.cc. +extern "C" void __attribute__((visibility("default"))) +argmax_float_1d_xla_impl(void* out, void** data) { + tensorflow::argmax_float_1d_xla_impl(out, data); +} diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc new file mode 100644 index 0000000000..be8ad2317c --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc @@ -0,0 +1,51 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +EIGEN_STRONG_INLINE void argmax_float_2d_xla_impl(void* out, void** data) { + // data is managed by the JIT code so msan can't tell it's initialized. + TF_ANNOTATE_MEMORY_IS_INITIALIZED(data, 4 * sizeof(void*)); + + float* in = static_cast(data[0]); + int64* in_sizes = static_cast(data[1]); + int64* out_sizes = static_cast(data[2]); + int32 dim = *static_cast(data[3]); + + Eigen::DSizes in_eig_sizes(in_sizes[0], in_sizes[1]); + TTypes::ConstTensor in_eig(in, in_eig_sizes); + + int64* out_t = static_cast(out); + Eigen::DSizes out_eig_sizes(out_sizes[0]); + TTypes::Tensor out_eig(out_t, out_eig_sizes); + + out_eig = in_eig.argmax(dim).cast(); +} + +} // namespace tensorflow + +// Implements argmax on CPU. This is called by an XLA custom call, set up by +// index_ops.cc. +extern "C" void __attribute__((visibility("default"))) +argmax_float_2d_xla_impl(void* out, void** data) { + tensorflow::argmax_float_2d_xla_impl(out, data); +} diff --git a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc new file mode 100644 index 0000000000..248984bcfe --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc @@ -0,0 +1,53 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/no_op.h" + +namespace tensorflow { +namespace { + +class L2LossOp : public XlaOpKernel { + public: + explicit L2LossOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape input_shape = ctx->InputShape(0); + + DataType dtype = ctx->input_type(0); + xla::ComputationBuilder* b = ctx->builder(); + + auto zero = XlaHelpers::Zero(b, dtype); + auto two = XlaHelpers::IntegerLiteral(b, dtype, 2); + const xla::Computation& add = *ctx->GetOrCreateAdd(dtype); + + std::vector dims(input_shape.dims()); + std::iota(dims.begin(), dims.end(), 0); + + // output = sum(t ** 2) / 2 + auto x = ctx->Input(0); + ctx->SetOutput(0, b->Div(b->Reduce(b->Mul(x, x), zero, add, dims), two)); + } +}; + +REGISTER_XLA_OP("L2Loss", L2LossOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc new file mode 100644 index 0000000000..93966d3d5a --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc @@ -0,0 +1,173 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/core/framework/kernel_def_builder.h" + +namespace tensorflow { +namespace { + +// Local response normalization +class LRNOp : public XlaOpKernel { + public: + explicit LRNOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("depth_radius", &depth_radius_)); + + // TODO(phawkins): handle non-float types for attributes. 
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("bias", &bias_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha", &alpha_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("beta", &beta_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape in_shape = ctx->InputShape(0); + OP_REQUIRES(ctx, in_shape.dims() == 4, + errors::InvalidArgument("in must be 4-dimensional")); + + xla::ComputationBuilder* builder = ctx->builder(); + xla::ComputationDataHandle input = ctx->Input(0); + + // sqr_sum[a, b, c, d] = + // sum(input[a, b, c, d - depth_radius : d + depth_radius + 1] ** 2) + // output = input / (bias + alpha * sqr_sum) ** beta + + // We use a window of depth_radius_ * 2 + 1, to account for the current + // element and a depth_radius_ on either side. + auto squared = builder->Mul(input, input); + auto sqr_sum = builder->ReduceWindow( + squared, XlaHelpers::Zero(builder, input_type(0)), + *ctx->GetOrCreateAdd(input_type(0)), + /* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1}, + /* window_strides = */ {1, 1, 1, 1}, xla::Padding::kSame); + + auto scale = builder->Pow( + builder->Add(builder->ConstantR0(bias_), + builder->Mul(builder->ConstantR0(alpha_), sqr_sum)), + builder->ConstantR0(-beta_)); + + ctx->SetOutput(0, builder->Mul(input, scale)); + } + + private: + int64 depth_radius_; + float bias_; + float alpha_; + float beta_; +}; + +REGISTER_XLA_OP("LRN", LRNOp); + +class LRNGradOp : public XlaOpKernel { + public: + explicit LRNGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("depth_radius", &depth_radius_)); + + // TODO(phawkins): handle non-float types for attributes. + OP_REQUIRES_OK(ctx, ctx->GetAttr("bias", &bias_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha", &alpha_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("beta", &beta_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape in_grads_shape = ctx->InputShape(0); + const TensorShape in_image_shape = ctx->InputShape(1); + const TensorShape out_image_shape = ctx->InputShape(2); + + OP_REQUIRES(ctx, in_grads_shape.dims() == 4 && in_image_shape.dims() == 4, + errors::InvalidArgument("inputs must be 4-dimensional")); + const int64 batch = in_grads_shape.dim_size(0); + const int64 rows = in_grads_shape.dim_size(1); + const int64 cols = in_grads_shape.dim_size(2); + const int64 depth = in_grads_shape.dim_size(3); + OP_REQUIRES( + ctx, in_image_shape.dim_size(0) == batch && + in_image_shape.dim_size(1) == rows && + in_image_shape.dim_size(2) == cols && + in_image_shape.dim_size(3) == depth && + out_image_shape.dim_size(0) == batch && + out_image_shape.dim_size(1) == rows && + out_image_shape.dim_size(2) == cols && + out_image_shape.dim_size(3) == depth, + errors::InvalidArgument( + "input_grads, input_image, and out_image should have the same " + "shape")); + + xla::ComputationBuilder* builder = ctx->builder(); + xla::ComputationDataHandle in_grads = ctx->Input(0); + xla::ComputationDataHandle in_image = ctx->Input(1); + xla::ComputationDataHandle out_image = ctx->Input(2); + + // This code is ported from tensorflow/core/kernels/lrn_op.cc. 
In Python + // pseudo-code, the Eigen code does this for each spatial position: + // grads = [0.0] * depth + // for j in range(depth): + // depth_begin = max(0, j - depth_radius) + // depth_end = min(depth, j + depth_radius + 1) + // + // norm = 0 + // for k in range(depth_begin, depth_end): + // norm += in_image[k] * in_image[k] + // norm = alpha * norm + bias + // + // for k in range(depth_begin, depth_end): + // dyi = -2.0 * alpha * beta * in_image[k] * out_image[j] / norm + // if k == j: + // dyi += norm ** (-beta) + // dyi *= out_grads[j] + // grads[k] += dyi + + auto squared = builder->Mul(in_image, in_image); + auto sqr_sum = builder->ReduceWindow( + squared, XlaHelpers::Zero(builder, input_type(0)), + *ctx->GetOrCreateAdd(input_type(0)), + /* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1}, + /* window_strides = */ {1, 1, 1, 1}, xla::Padding::kSame); + + auto norm = + builder->Add(builder->ConstantR0(bias_), + builder->Mul(builder->ConstantR0(alpha_), sqr_sum)); + + auto dy = builder->Mul( + builder->Mul(builder->ConstantR0(-2.0f * alpha_ * beta_), + builder->Div(out_image, norm)), + in_grads); + + auto dy_reduced = builder->ReduceWindow( + dy, XlaHelpers::Zero(builder, input_type(0)), + *ctx->GetOrCreateAdd(input_type(0)), + /* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1}, + /* window_strides = */ {1, 1, 1, 1}, xla::Padding::kSame); + + xla::ComputationDataHandle gradients = builder->Add( + builder->Mul(in_image, dy_reduced), + builder->Mul(in_grads, + builder->Pow(norm, builder->ConstantR0(-beta_)))); + + ctx->SetOutput(0, gradients); + } + + private: + int64 depth_radius_; + float bias_; + float alpha_; + float beta_; +}; + +REGISTER_XLA_OP("LRNGrad", LRNGradOp); + +} // anonymous namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc new file mode 100644 index 0000000000..5af6a79f3e --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc @@ -0,0 +1,88 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA-specific MatMul Op. + +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class MatMulOp : public XlaOpKernel { + public: + explicit MatMulOp(OpKernelConstruction* ctx, bool is_sparse = false) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_)); + if (is_sparse) { + // SparseMatMul is actually dense matmul with a hint that one or + // both of the inputs may contain a lot of zeroes. On CPU these + // inputs are dynamically converted to sparse representation + // before multiplication. 
For now in XLA we ignore the hints + // and always do dense multiplication. + bool dummy_is_sparse; + OP_REQUIRES_OK(ctx, ctx->GetAttr("a_is_sparse", &dummy_is_sparse)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("b_is_sparse", &dummy_is_sparse)); + } + } + + ~MatMulOp() override = default; + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape a_shape = ctx->InputShape(0); + const TensorShape b_shape = ctx->InputShape(1); + + // Check that the dimensions of the two matrices are valid. + OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(a_shape), + errors::InvalidArgument("In[0] is not a matrix")); + OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(b_shape), + errors::InvalidArgument("In[1] is not a matrix")); + int first_index = transpose_a_ ? 0 : 1; + int second_index = transpose_b_ ? 1 : 0; + + OP_REQUIRES(ctx, + a_shape.dim_size(first_index) == b_shape.dim_size(second_index), + errors::InvalidArgument("Matrix size-compatible: In[0]: ", + a_shape.DebugString(), ", In[1]: ", + b_shape.DebugString())); + + xla::ComputationDataHandle a = ctx->Input(0); + xla::ComputationDataHandle b = ctx->Input(1); + auto lhs = (transpose_a_) ? ctx->builder()->Transpose(a, {1, 0}) : a; + auto rhs = (transpose_b_) ? ctx->builder()->Transpose(b, {1, 0}) : b; + ctx->SetOutput(0, ctx->builder()->Dot(lhs, rhs)); + } + + private: + bool transpose_a_; + bool transpose_b_; +}; + +REGISTER_XLA_OP("MatMul", MatMulOp); + +class SparseMatMulOp : public MatMulOp { + public: + explicit SparseMatMulOp(OpKernelConstruction* ctx) : MatMulOp(ctx, true) {} + + ~SparseMatMulOp() override = default; +}; + +REGISTER_XLA_OP("SparseMatMul", SparseMatMulOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/no_op.cc b/tensorflow/compiler/tf2xla/kernels/no_op.cc new file mode 100644 index 0000000000..806bfc604f --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/no_op.cc @@ -0,0 +1,24 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/no_op.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/core/framework/kernel_def_builder.h" + +namespace tensorflow { + +REGISTER_XLA_OP("NoOp", NoOp); + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/pack_op.cc b/tensorflow/compiler/tf2xla/kernels/pack_op.cc new file mode 100644 index 0000000000..7456d92de0 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/pack_op.cc @@ -0,0 +1,93 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA Pack operator. + +#include +#include + +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/kernels/concat_lib.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace { + +class PackOp : public XlaOpKernel { + public: + explicit PackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &axis_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + std::vector values; + std::vector shapes; + OP_REQUIRES_OK(ctx, ctx->InputList("values", &values, &shapes)); + const int num = values.size(); + + OP_REQUIRES(ctx, num >= 0, + errors::InvalidArgument("Pack requires >= 1 arguments")); + + // Verify that all input shapes match + for (int i = 1; i < num; i++) { + OP_REQUIRES(ctx, shapes[0].IsSameSize(shapes[i]), + errors::InvalidArgument( + "Shapes of all inputs must match: values[0].shape = ", + shapes[0].DebugString(), " != values[", i, "].shape = ", + shapes[i].DebugString())); + } + + int expanded_num_dims = shapes[0].dims() + 1; + int axis = axis_; + if (axis < 0) axis += expanded_num_dims; + + OP_REQUIRES(ctx, 0 <= axis && axis < expanded_num_dims, + errors::InvalidArgument("axis = ", axis_, " not in [", + -expanded_num_dims, ", ", + expanded_num_dims, ")")); + + std::vector reshaped_inputs(num); + + TensorShape child_shape(shapes[0]); + child_shape.InsertDim(axis, 1); + + for (int i = 0; i < num; ++i) { + // Reshape the inputs to have an extra dimension of size 1. + reshaped_inputs[i] = + ctx->builder()->Reshape(values[i], child_shape.dim_sizes()); + } + + ctx->SetOutput(0, ctx->builder()->ConcatInDim(reshaped_inputs, axis)); + } + + private: + int axis_; +}; + +REGISTER_XLA_OP("Pack", PackOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/pad_op.cc b/tensorflow/compiler/tf2xla/kernels/pad_op.cc new file mode 100644 index 0000000000..2846414c5e --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/pad_op.cc @@ -0,0 +1,80 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/register_types.h" + +namespace tensorflow { +namespace { + +class PadOp : public XlaOpKernel { + public: + explicit PadOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape input_shape = ctx->InputShape(0); + const TensorShape pad_shape = ctx->InputShape(1); + const int dims = input_shape.dims(); + OP_REQUIRES( + ctx, + TensorShapeUtils::IsMatrix(pad_shape) && pad_shape.dim_size(1) == 2, + errors::InvalidArgument("paddings must be a matrix with 2 columns: ", + pad_shape.DebugString())); + const int fixed_dims = + (allow_legacy_scalars() && dims == 0 && pad_shape.dim_size(0) == 1) + ? 1 + : dims; + OP_REQUIRES( + ctx, fixed_dims == pad_shape.dim_size(0), + errors::InvalidArgument( + "The first dimension of paddings must be the rank of inputs", + pad_shape.DebugString(), " ", input_shape.DebugString())); + + if (fixed_dims == 0) { + // Tensor is rank 0. Return it unchanged. + ctx->SetOutput(0, ctx->Input(0)); + return; + } + + // Evaluate the 'padding' constant input, reshaping to a matrix. + xla::Literal pad_literal; + OP_REQUIRES_OK( + ctx, ctx->ConstantInputReshaped(1, {fixed_dims, 2}, &pad_literal)); + + xla::PaddingConfig config; + for (int i = 0; i < fixed_dims; ++i) { + auto* dim = config.add_dimensions(); + int before = xla::LiteralUtil::Get(pad_literal, {i, 0}); + int after = xla::LiteralUtil::Get(pad_literal, {i, 1}); + OP_REQUIRES(ctx, before >= 0 && after >= 0, + errors::InvalidArgument("Paddings must be non-negative: ", + before, " ", after)); + dim->set_edge_padding_low(before); + dim->set_edge_padding_high(after); + } + + auto zero = XlaHelpers::Zero(ctx->builder(), input_type(0)); + ctx->SetOutput(0, ctx->builder()->Pad(ctx->Input(0), zero, config)); + } +}; + +REGISTER_XLA_OP("Pad", PadOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc new file mode 100644 index 0000000000..7a1ce2db85 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -0,0 +1,374 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// XLA specific pooling ops. + +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/kernels/conv_grad_ops.h" +#include "tensorflow/core/kernels/pooling_ops_common.h" + +namespace tensorflow { +namespace { + +// Superclass of pooling ops. +class PoolingOp : public XlaOpKernel { + public: + explicit PoolingOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + // Data format doesn't matter since the kernel is specified explicitly. + std::vector ksize_int; + std::vector stride_int; + OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_int)); + OP_REQUIRES(ctx, ksize_int.size() == 4, + errors::InvalidArgument("Sliding window ksize field must " + "specify 4 dimensions")); + OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_int)); + OP_REQUIRES(ctx, stride_int.size() == 4, + errors::InvalidArgument("Sliding window stride field must " + "specify 4 dimensions")); + for (int i = 0; i < 4; ++i) { + ksize_.push_back(ksize_int[i]); + stride_.push_back(stride_int[i]); + } + Padding padding; + OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding)); + padding_ = (padding == VALID) ? xla::Padding::kValid : xla::Padding::kSame; + } + + // Method that builds an initial value to use in reductions. + virtual xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b, + DataType data_type) = 0; + + // The reduction operation to apply to each window. + virtual const xla::Computation* Reduction(XlaOpKernelContext* ctx, + DataType dtype) = 0; + + // A post-processing operation to apply on the outputs of the ReduceWindow. 
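+  // For example, MaxPool below returns the ReduceWindow result unchanged,
+  // while AvgPool divides it by the per-window element counts.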
+ virtual xla::ComputationDataHandle PostProcessOutput( + XlaOpKernelContext* ctx, const xla::ComputationDataHandle& output, + DataType dtype, const TensorShape& input_shape) = 0; + + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationDataHandle input = ctx->Input(0); + const TensorShape input_shape = ctx->InputShape(0); + + const DataType type = input_type(0); + xla::ComputationDataHandle pooled = ctx->builder()->ReduceWindow( + input, InitValue(ctx->builder(), type), *Reduction(ctx, type), ksize_, + stride_, padding_); + ctx->SetOutput(0, PostProcessOutput(ctx, pooled, type, input_shape)); + } + + protected: + std::vector ksize_; + std::vector stride_; + xla::Padding padding_; +}; + +class MaxPoolOp : public PoolingOp { + public: + explicit MaxPoolOp(OpKernelConstruction* ctx) : PoolingOp(ctx) {} + + xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b, + DataType data_type) override { + return XlaHelpers::MinValue(b, data_type); + } + + const xla::Computation* Reduction(XlaOpKernelContext* ctx, + DataType dtype) override { + return ctx->GetOrCreateMax(dtype); + } + + xla::ComputationDataHandle PostProcessOutput( + XlaOpKernelContext* ctx, const xla::ComputationDataHandle& output, + DataType dtype, const TensorShape& input_shape) override { + return output; + } +}; + +REGISTER_XLA_OP("MaxPool", MaxPoolOp); + +// Common computation shared between AvgPool and AvgPoolGrad. Divide each +// element of an image by the count of elements that contributed to that +// element during pooling. +static xla::ComputationDataHandle AvgPoolDivideByCount( + XlaOpKernelContext* ctx, const xla::ComputationDataHandle& output, + DataType dtype, const TensorShape& input_shape, xla::Padding padding, + const std::vector& ksize, const std::vector& stride, + TensorFormat data_format) { + if (padding == xla::Padding::kValid) { + // In VALID padding, all windows have the same number of elements + // contributing to each average. Divide by the window size everywhere to + // get the average. + int64 window_size = std::accumulate(ksize.begin(), ksize.end(), 1, + [](int64 a, int64 b) { return a * b; }); + + auto divisor = + XlaHelpers::IntegerLiteral(ctx->builder(), dtype, window_size); + return ctx->builder()->Div(output, divisor); + } else { + // For SAME padding, the padding shouldn't be included in the + // counts. We use another ReduceWindow to find the right counts. + + // TODO(phawkins): use a less brute-force way to compute this. Only + // the boundary regions will have interesting values here. + + int height_dim = GetTensorDimIndex(data_format, 'H'); + int width_dim = GetTensorDimIndex(data_format, 'W'); + CHECK_LT(height_dim, width_dim); + + // Build a matrix of all 1s, with the same width/height as the input. + auto ones = ctx->builder()->Broadcast( + XlaHelpers::One(ctx->builder(), dtype), + {input_shape.dim_size(height_dim), input_shape.dim_size(width_dim)}); + + // Perform a ReduceWindow with the same window size, strides, and padding + // to count the number of contributions to each result element. 
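+    // As an illustrative 1-D example: with a window of 3, stride 1 and SAME
+    // padding over an input of length 4, the edge windows overlap the
+    // padding, so the counts come out as {2, 3, 3, 2} and the edge outputs
+    // are divided by 2 rather than 3.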
+ auto counts = ctx->builder()->ReduceWindow( + ones, XlaHelpers::Zero(ctx->builder(), dtype), + *ctx->GetOrCreateAdd(dtype), {ksize[height_dim], ksize[width_dim]}, + {stride[height_dim], stride[width_dim]}, xla::Padding::kSame); + + return ctx->builder()->Div(output, counts, {height_dim, width_dim}); + } +} + +class AvgPoolOp : public PoolingOp { + public: + explicit AvgPoolOp(OpKernelConstruction* ctx) : PoolingOp(ctx) { + string data_format_str; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); + OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), + errors::InvalidArgument("Invalid data format")); + } + + xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b, + DataType data_type) override { + return XlaHelpers::Zero(b, data_type); + } + + const xla::Computation* Reduction(XlaOpKernelContext* ctx, + DataType dtype) override { + return ctx->GetOrCreateAdd(dtype); + } + + xla::ComputationDataHandle PostProcessOutput( + XlaOpKernelContext* ctx, const xla::ComputationDataHandle& output, + DataType dtype, const TensorShape& input_shape) override { + return AvgPoolDivideByCount(ctx, output, dtype, input_shape, padding_, + ksize_, stride_, data_format_); + } + + private: + TensorFormat data_format_; +}; + +REGISTER_XLA_OP("AvgPool", AvgPoolOp); + +// The operation to compute MaxPool gradients. +// It takes three inputs: +// - The original input tensor +// - The original output tensor +// - Backprop tensor for output +// It produces one output: backprop tensor for input. +class MaxPoolGradOp : public XlaOpKernel { + public: + explicit MaxPoolGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + string data_format; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); + OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), + errors::InvalidArgument("Invalid data format")); + OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_)); + OP_REQUIRES(ctx, ksize_.size() == 4, + errors::InvalidArgument("Sliding window ksize field must " + "specify 4 dimensions")); + OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_)); + OP_REQUIRES(ctx, stride_.size() == 4, + errors::InvalidArgument("Sliding window strides field must " + "specify 4 dimensions")); + OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape tensor_in_shape = ctx->InputShape(0); + const TensorShape tensor_out_shape = ctx->InputShape(1); + const TensorShape out_backprop_shape = ctx->InputShape(2); + + // For maxpooling, tensor_in should have 4 dimensions. + OP_REQUIRES(ctx, tensor_in_shape.dims() == 4, + errors::InvalidArgument("tensor_in must be 4-dimensional")); + OP_REQUIRES(ctx, tensor_out_shape.dims() == 4, + errors::InvalidArgument("tensor_out must be 4-dimensional")); + // For maxpooling, out_backprop should have 4 dimensions. + OP_REQUIRES(ctx, out_backprop_shape.dims() == 4, + errors::InvalidArgument("out_backprop must be 4-dimensional")); + + // TODO(phawkins): The XLA version doesn't need tensor_out. Investigate + // whether this is a good time/space tradeoff. + auto input = ctx->Input(0); + auto out_backprop = ctx->Input(2); + + xla::Padding xla_padding = + (padding_ == VALID) ? 
xla::Padding::kValid : xla::Padding::kSame; + + xla::PrimitiveType element_type; + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(2), &element_type)); + xla::ComputationDataHandle init_value = + XlaHelpers::Zero(ctx->builder(), input_type(2)); + auto select = CreateScalarGeComputation(element_type, ctx->builder()); + auto scatter = CreateScalarAddComputation(element_type, ctx->builder()); + xla::ComputationDataHandle gradients = ctx->builder()->SelectAndScatter( + input, select, ksize_, stride_, xla_padding, out_backprop, init_value, + scatter); + + ctx->SetOutput(0, gradients); + } + + private: + std::vector ksize_; + std::vector stride_; + Padding padding_; + TensorFormat data_format_; +}; + +REGISTER_XLA_OP("MaxPoolGrad", MaxPoolGradOp); + +// Average-pooling gradient +class AvgPoolGradOp : public XlaOpKernel { + public: + explicit AvgPoolGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + string data_format; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); + OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), + errors::InvalidArgument("Invalid data format")); + OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_)); + OP_REQUIRES(ctx, ksize_.size() == 4, + errors::InvalidArgument("Sliding window ksize field must " + "specify 4 dimensions")); + OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_)); + OP_REQUIRES(ctx, stride_.size() == 4, + errors::InvalidArgument("Sliding window strides field must " + "specify 4 dimensions")); + OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); + OP_REQUIRES(ctx, ksize_[0] == 1 && stride_[0] == 1, + errors::Unimplemented( + "Pooling is not yet supported on the batch dimension.")); + } + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape gradients_shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &gradients_shape)); + + const TensorShape out_backprop_shape = ctx->InputShape(1); + + // For avgpooling, tensor_in_shape should have 1 dimension, and 4 elements. + OP_REQUIRES( + ctx, gradients_shape.dims() == 4, + errors::InvalidArgument("orig_input_shape must have 4 elements")); + + // For avgpooling, out_backprop should have 4 dimensions. + OP_REQUIRES(ctx, out_backprop_shape.dims() == 4, + errors::InvalidArgument("out_backprop must be 4-dimensional")); + + int height_dim = GetTensorDimIndex(data_format_, 'H'); + int width_dim = GetTensorDimIndex(data_format_, 'W'); + int depth = GetTensorDim(out_backprop_shape, data_format_, 'C'); + + // We can think of average-pooling as: + // * a convolution with a kernel consisting entirely of 1s, where the + // input feature and output feature are equal, and 0s everywhere else. + // * followed by dividing by the counts. + // + // This then gives us an algorithm to build the gradient: + // * divide out_backprop by the counts, followed by + // * Conv2DBackpropInput specialized for that kernel, which simplifies to + // a Pad and a ReduceWindow. + // + // For an explanation of backpropagation for convolution, see the comments + // in third_party/tensorflow/core/kernels/conv_grad_ops.h + + // TF filter shape is [ H, W, inC, outC ] + TensorShape filter_shape( + {ksize_[height_dim], ksize_[width_dim], depth, depth}); + + // Reuse the logic from Conv2DBackpropInput to compute padding. 
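+    // Illustration of the interior padding used below: with stride 2 in a
+    // spatial dimension, stride - 1 = 1 zero is inserted between adjacent
+    // out_backprop values, spacing the count-divided gradients back out to
+    // input resolution before the summing ReduceWindow accumulates them into
+    // the input gradients.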
+ Conv2DBackpropDimensions dims; + OP_REQUIRES_OK( + ctx, Conv2DBackpropComputeDimensions( + "AvgPoolGrad", gradients_shape, filter_shape, + out_backprop_shape, stride_, padding_, data_format_, &dims)); + + auto out_backprop = ctx->Input(1); + + // The input gradients are computed by a convolution of the output + // gradients + // and the filter, with some appropriate padding. See the comment at + // the top of conv_grad_ops.h for details. + DataType dtype = input_type(1); + + xla::Padding xla_padding = + (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame; + + // Divide the out_backprop values by the counts for each spatial position. + std::vector stride_int64s(stride_.begin(), stride_.end()); + auto out_backprop_div = + AvgPoolDivideByCount(ctx, out_backprop, dtype, gradients_shape, + xla_padding, ksize_, stride_int64s, data_format_); + + // Pad the gradients in the spatial dimensions. We use the same padding + // as Conv2DBackpropInput. + xla::PaddingConfig padding_config = xla::MakeNoPaddingConfig(4); + auto* row_padding = padding_config.mutable_dimensions(height_dim); + row_padding->set_edge_padding_low(dims.rows.pad_before); + row_padding->set_edge_padding_high(dims.rows.pad_after); + row_padding->set_interior_padding(dims.rows.stride - 1); + + auto* col_padding = padding_config.mutable_dimensions(width_dim); + col_padding->set_edge_padding_low(dims.cols.pad_before); + col_padding->set_edge_padding_high(dims.cols.pad_after); + col_padding->set_interior_padding(dims.cols.stride - 1); + + auto zero = XlaHelpers::Zero(ctx->builder(), dtype); + auto padded_gradients = + ctx->builder()->Pad(out_backprop_div, zero, padding_config); + + // in_backprop = padded_gradients ones + xla::ComputationDataHandle in_backprop = ctx->builder()->ReduceWindow( + padded_gradients, zero, *ctx->GetOrCreateAdd(dtype), ksize_, + /* window_strides = */ {1, 1, 1, 1}, xla::Padding::kValid); + + ctx->SetOutput(0, in_backprop); + } + + private: + std::vector ksize_; + std::vector stride_; + Padding padding_; + TensorFormat data_format_; +}; + +REGISTER_XLA_OP("AvgPoolGrad", AvgPoolGradOp); + +} // anonymous namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc new file mode 100644 index 0000000000..4ffe278d1c --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -0,0 +1,116 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA implementations of Random ops +// TODO(misard,phawkins): handle random number generator seeds/states correctly. +// TODO(misard,phawkins): add tests. 
+ +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" + +namespace tensorflow { +namespace { + +class RandomUniformOp : public XlaOpKernel { + public: + explicit RandomUniformOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape)); + + const DataType dtype = output_type(0); + xla::Shape xla_shape; + OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, shape, &xla_shape)); + + xla::ComputationBuilder* b = ctx->builder(); + xla::ComputationDataHandle result = b->RngUniform( + XlaHelpers::Zero(b, dtype), XlaHelpers::One(b, dtype), xla_shape); + + ctx->SetOutput(0, result); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(RandomUniformOp); +}; + +REGISTER_XLA_OP("RandomUniform", RandomUniformOp); + +class RandomUniformIntOp : public XlaOpKernel { + public: + explicit RandomUniformIntOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape)); + xla::Shape xla_shape; + OP_REQUIRES_OK(ctx, + TensorShapeToXLAShape(input_type(1), shape, &xla_shape)); + + const TensorShape minval_shape = ctx->InputShape(1); + const TensorShape maxval_shape = ctx->InputShape(2); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(minval_shape), + errors::InvalidArgument("minval must be 0-D, got shape ", + minval_shape.DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(maxval_shape), + errors::InvalidArgument("maxval must be 0-D, got shape ", + maxval_shape.DebugString())); + + auto minval = ctx->Input(1); + auto maxval = ctx->Input(2); + ctx->SetOutput(0, ctx->builder()->RngUniform(minval, maxval, xla_shape)); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(RandomUniformIntOp); +}; + +REGISTER_XLA_OP("RandomUniformInt", RandomUniformIntOp); + +class RandomStandardNormalOp : public XlaOpKernel { + public: + explicit RandomStandardNormalOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + const DataType dtype = output_type(0); + + TensorShape shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape)); + xla::Shape xla_shape; + OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, shape, &xla_shape)); + + xla::ComputationBuilder* b = ctx->builder(); + + // Normal distribution with a mean of 0 and a standard deviation of 1: + xla::ComputationDataHandle result = b->RngNormal( + XlaHelpers::Zero(b, dtype), XlaHelpers::One(b, dtype), xla_shape); + + ctx->SetOutput(0, result); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(RandomStandardNormalOp); +}; + +REGISTER_XLA_OP("RandomStandardNormal", RandomStandardNormalOp); + +} // anonymous namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc new file mode 100644 index 0000000000..ac929af2e2 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc @@ -0,0 +1,157 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA-specific reduction Ops. + +#include "tensorflow/compiler/tf2xla/kernels/reduction_ops.h" +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/kernel_def_builder.h" + +namespace tensorflow { +namespace { + +class SumOp : public XlaReductionOp { + public: + explicit SumOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {} + void BuildReducer(xla::ComputationBuilder* builder, + const xla::ComputationDataHandle& scalar_lhs, + const xla::ComputationDataHandle& scalar_rhs) override { + builder->Add(scalar_lhs, scalar_rhs); + } +}; + +REGISTER_XLA_OP("Sum", SumOp); + +class ProdOp : public XlaReductionOp { + public: + explicit ProdOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {} + + xla::ComputationDataHandle InitialValue( + xla::ComputationBuilder* builder) override { + return XlaHelpers::One(builder, input_type(0)); + } + + void BuildReducer(xla::ComputationBuilder* builder, + const xla::ComputationDataHandle& scalar_lhs, + const xla::ComputationDataHandle& scalar_rhs) override { + builder->Mul(scalar_lhs, scalar_rhs); + } +}; + +REGISTER_XLA_OP("Prod", ProdOp); + +class MinOp : public XlaReductionOp { + public: + explicit MinOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {} + + xla::ComputationDataHandle InitialValue( + xla::ComputationBuilder* builder) override { + xla::PrimitiveType type; + TF_CHECK_OK(DataTypeToPrimitiveType(input_type(0), &type)); + return builder->ConstantLiteral(xla::LiteralUtil::MaxValue(type)); + } + + void BuildReducer(xla::ComputationBuilder* builder, + const xla::ComputationDataHandle& scalar_lhs, + const xla::ComputationDataHandle& scalar_rhs) override { + builder->Min(scalar_lhs, scalar_rhs); + } +}; + +REGISTER_XLA_OP("Min", MinOp); + +class MaxOp : public XlaReductionOp { + public: + explicit MaxOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {} + + xla::ComputationDataHandle InitialValue( + xla::ComputationBuilder* builder) override { + xla::PrimitiveType type; + TF_CHECK_OK(DataTypeToPrimitiveType(input_type(0), &type)); + return builder->ConstantLiteral(xla::LiteralUtil::MinValue(type)); + } + + void BuildReducer(xla::ComputationBuilder* builder, + const xla::ComputationDataHandle& scalar_lhs, + const xla::ComputationDataHandle& scalar_rhs) override { + builder->Max(scalar_lhs, scalar_rhs); + } +}; + +REGISTER_XLA_OP("Max", MaxOp); + +class MeanOp : public XlaReductionOp { + public: + explicit MeanOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {} + + void BuildReducer(xla::ComputationBuilder* builder, + const xla::ComputationDataHandle& scalar_lhs, + const xla::ComputationDataHandle& scalar_rhs) override { + builder->Add(scalar_lhs, scalar_rhs); + } + + bool BuildFinalizer(xla::ComputationBuilder* builder, + const 
xla::ComputationDataHandle& scalar_argument, + int64 num_elements_reduced) override { + auto divisor = XlaHelpers::IntegerLiteral(builder, input_type(0), + num_elements_reduced); + builder->Div(scalar_argument, divisor); + return true; + } +}; + +REGISTER_XLA_OP("Mean", MeanOp); + +class AllOp : public XlaReductionOp { + public: + explicit AllOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {} + + xla::ComputationDataHandle InitialValue( + xla::ComputationBuilder* builder) override { + return builder->ConstantR0(true); + } + + void BuildReducer(xla::ComputationBuilder* builder, + const xla::ComputationDataHandle& scalar_lhs, + const xla::ComputationDataHandle& scalar_rhs) override { + builder->LogicalAnd(scalar_lhs, scalar_rhs); + } +}; + +REGISTER_XLA_OP("All", AllOp); + +class AnyOp : public XlaReductionOp { + public: + explicit AnyOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {} + + xla::ComputationDataHandle InitialValue( + xla::ComputationBuilder* builder) override { + return builder->ConstantR0(false); + } + + void BuildReducer(xla::ComputationBuilder* builder, + const xla::ComputationDataHandle& scalar_lhs, + const xla::ComputationDataHandle& scalar_rhs) override { + builder->LogicalOr(scalar_lhs, scalar_rhs); + } +}; + +REGISTER_XLA_OP("Any", AnyOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h new file mode 100644 index 0000000000..7f0dd26f91 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h @@ -0,0 +1,71 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA-specific base classes for Reduction Ops. + +#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_REDUCTION_OPS_H_ +#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_REDUCTION_OPS_H_ + +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { + +// Reduction operations. The base class contains pure virtual methods +// to override: description is a textual description of the mapped +// function; InitialValue constructs the base case for the reduction; +// BuildReducer adds the implementation of the reduction lambda to a +// xla::ComputationBuilder and BuildFinalizer adds the +// implementation of the finalizer lambda (if there is one) to a +// xla::ComputationBuilder. +class XlaReductionOp : public XlaOpKernel { + public: + explicit XlaReductionOp(OpKernelConstruction* ctx); + ~XlaReductionOp() override {} + + // Return the base case for the reduction. Defaults to zero. + virtual xla::ComputationDataHandle InitialValue( + xla::ComputationBuilder* builder); + + // Implement the (scalar,scalar)->scalar lambda that should be + // applied to each pair of elements to be reduced. 
The desired + // computation should be added to 'builder' and + // '(scalar_lhs,scalar_rhs)' are the function's inputs. + virtual void BuildReducer(xla::ComputationBuilder* builder, + const xla::ComputationDataHandle& scalar_lhs, + const xla::ComputationDataHandle& scalar_rhs) = 0; + + // Implement the scalar->scalar lambda that should be applied to + // each element to be finalized. The desired computation should be + // added to 'builder' and 'scalar_argument' is the function's + // input. 'num_elements_reduced' is the number of elements that contributed + // to the reduction. If the reduction has a finalizer return true, otherwise + // return false and any computation added to builder will be + // ignored. Defaults to return false. + virtual bool BuildFinalizer(xla::ComputationBuilder* builder, + const xla::ComputationDataHandle& scalar_argument, + int64 num_elements_reduced); + + void Compile(XlaOpKernelContext* ctx) override; + + private: + // True if the number of dimensions should be maintained. + bool keep_dims_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_REDUCTION_OPS_H_ diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc new file mode 100644 index 0000000000..d6b085e897 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc @@ -0,0 +1,150 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA-specific reduction Ops. + +#include "tensorflow/compiler/tf2xla/kernels/reduction_ops.h" +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/kernel_def_builder.h" + +namespace tensorflow { + +XlaReductionOp::XlaReductionOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + const DataType dt = BaseType(input_type(0)); + OP_REQUIRES_OK(ctx, ctx->MatchSignature({dt, DT_INT32}, {dt})); + + OP_REQUIRES_OK(ctx, ctx->GetAttr("keep_dims", &keep_dims_)); +} + +// Return the base case for the reduction. Defaults to zero. +xla::ComputationDataHandle XlaReductionOp::InitialValue( + xla::ComputationBuilder* builder) { + return XlaHelpers::Zero(builder, input_type(0)); +} + +// Unless BuildFinalizer is overridden the reduction has no +// finalizer. 
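+// Mean, for example, overrides this to divide the summed value by
+// num_elements_reduced.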
+bool XlaReductionOp::BuildFinalizer( + xla::ComputationBuilder* builder, + const xla::ComputationDataHandle& scalar_argument, + int64 num_elements_reduced) { + return false; +} + +void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { + const TensorShape data_shape = ctx->InputShape(0); + const TensorShape axes_tensor_shape = ctx->InputShape(1); + VLOG(1) << "ReductionOp: " << ctx->op_kernel().name(); + + if (axes_tensor_shape.num_elements() == 0) { + // The reduction axes is an empty vector, which means there are no + // axes to reduce so just pass the input directly through to the + // output. + ctx->SetOutput(0, ctx->Input(0)); + return; + } + + // Evaluate the constant, reshaping to a 1-vector if it is a scalar. + xla::Literal axes_literal; + OP_REQUIRES_OK(ctx, + ctx->ConstantInputReshaped( + 1, {axes_tensor_shape.num_elements()}, &axes_literal)); + + VLOG(1) << "data shape: " << data_shape.DebugString(); + VLOG(1) << "axes : " << xla::LiteralUtil::ToString(axes_literal); + + gtl::InlinedVector bitmap(data_shape.dims(), false); + std::vector xla_axes; + int64 num_elements_reduced = 1LL; + for (int64 i = 0; i < axes_tensor_shape.num_elements(); ++i) { + int32 index = xla::LiteralUtil::Get(axes_literal, {i}); + OP_REQUIRES(ctx, + !(index < -data_shape.dims() || index >= data_shape.dims()), + errors::InvalidArgument("Invalid reduction dimension (", index, + " for input with ", data_shape.dims(), + " dimension(s)")); + index = (index + data_shape.dims()) % data_shape.dims(); + bitmap[index] = true; + xla_axes.push_back(index); + num_elements_reduced *= data_shape.dim_size(index); + } + + std::vector final_shape; + for (int i = 0; i < data_shape.dims(); ++i) { + if (!bitmap[i]) { + // If we are not reducing along dimension i. + int64 dim = data_shape.dim_size(i); + final_shape.push_back(dim); + } else if (keep_dims_) { + // We are reducing along dimension i, but we want to keep the + // same number of dimensions, so we set the dimension of i to + // '1'. + final_shape.push_back(1); + } + } + + string desc = ctx->op_kernel().name(); + + // Call virtual method to get the initial value. + const xla::ComputationDataHandle initial = InitialValue(ctx->builder()); + // Construct the builder for the reduction lambda. + xla::ComputationBuilder r(ctx->builder()->client(), + strings::StrCat(desc, "-reduction")); + xla::PrimitiveType type; + TF_CHECK_OK(DataTypeToPrimitiveType(input_type(0), &type)); + // Make two scalar parameters of the desired type for the lambda. + xla::ComputationDataHandle rx = + r.Parameter(0, xla::ShapeUtil::MakeShape(type, {}), "x"); + xla::ComputationDataHandle ry = + r.Parameter(1, xla::ShapeUtil::MakeShape(type, {}), "y"); + + auto data = ctx->Input(0); + + // Call virtual method to build the reduction lambda. + BuildReducer(&r, rx, ry); + xla::Computation reduction_computation = r.Build().ConsumeValueOrDie(); + xla::ComputationDataHandle reduce = + ctx->builder()->Reduce(data, initial, reduction_computation, xla_axes); + + // Construct the builder for the finalizer lambda. + xla::ComputationBuilder f(ctx->builder()->client(), + strings::StrCat(desc, "-finalizer")); + // Make the scalar parameter of the desired type for the lambda. + xla::ComputationDataHandle fx = + f.Parameter(0, xla::ShapeUtil::MakeShape(type, {}), "x"); + // Call virtual method to build the finalizer lambda. 
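+  // If the op supplies one, it is applied elementwise to the reduced value
+  // via Map below; otherwise the Reduce result is used as-is.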
+ bool has_finalizer = BuildFinalizer(&f, fx, num_elements_reduced); + xla::Computation finalizer_computation = f.Build().ConsumeValueOrDie(); + xla::ComputationDataHandle pre_reshaped_data; + if (has_finalizer) { + // This reduction Op includes a finalizer so run it as a Map. + pre_reshaped_data = ctx->builder()->Map({reduce}, finalizer_computation); + } else { + pre_reshaped_data = reduce; + } + + xla::ComputationDataHandle result; + if (keep_dims_) { + result = ctx->builder()->Reshape(pre_reshaped_data, final_shape); + } else { + result = pre_reshaped_data; + } + ctx->SetOutput(0, result); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/relu_op.cc b/tensorflow/compiler/tf2xla/kernels/relu_op.cc new file mode 100644 index 0000000000..3cddff9df4 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/relu_op.cc @@ -0,0 +1,93 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Native XLA implementations of XLA Relu Ops + +#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/no_op.h" + +namespace tensorflow { +namespace { + +class ReluOp : public XlaOpKernel { + public: + explicit ReluOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + // Computes the max of the scalar input x and 0. + void Compile(XlaOpKernelContext* ctx) { + xla::ComputationBuilder* builder = ctx->builder(); + auto zero = XlaHelpers::Zero(builder, input_type(0)); + ctx->SetOutput(0, builder->Max(zero, ctx->Input(0))); + } +}; + +class Relu6Op : public XlaOpKernel { + public: + explicit Relu6Op(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + // Clamp the scalar input between 0 and 6. + void Compile(XlaOpKernelContext* ctx) { + xla::ComputationBuilder* builder = ctx->builder(); + auto zero = XlaHelpers::Zero(builder, input_type(0)); + auto six = XlaHelpers::IntegerLiteral(builder, input_type(0), 6); + ctx->SetOutput(0, builder->Clamp(zero, ctx->Input(0), six)); + } +}; + +// A subclass of a XlaBinaryMapOp must build the lambda computation +// that describes the (scalar,scalar)->scalar function to apply to +// each element of the input. We have to use XlaBinaryMapOp instead of +// XlaBinaryOp here because XLA Select does not do automatic +// broadcasting. +class ReluGradOp : public XlaBinaryMapOp { + public: + explicit ReluGradOp(OpKernelConstruction* ctx) : XlaBinaryMapOp(ctx) {} + // Return the lhs (incoming gradient) if the rhs (input feature) > 0, + // otherwise return 0. 
+ void BuildMapLambda(xla::ComputationBuilder* b, + const xla::ComputationDataHandle& gradient, + const xla::ComputationDataHandle& feature) override { + const auto zero = XlaHelpers::Zero(b, input_type(0)); + b->Select(b->Gt(feature, zero), gradient, zero); + } +}; + +class Relu6GradOp : public XlaBinaryMapOp { + public: + explicit Relu6GradOp(OpKernelConstruction* ctx) : XlaBinaryMapOp(ctx) {} + // Return the lhs (incoming gradient) if the rhs (input feature) > 0, + // otherwise return 0. + void BuildMapLambda(xla::ComputationBuilder* b, + const xla::ComputationDataHandle& gradient, + const xla::ComputationDataHandle& feature) override { + const auto zero = XlaHelpers::Zero(b, input_type(0)); + auto six = XlaHelpers::IntegerLiteral(b, input_type(0), 6); + b->Select(b->LogicalAnd(b->Lt(feature, six), b->Gt(feature, zero)), + gradient, zero); + } +}; + +REGISTER_XLA_OP("Relu", ReluOp); +REGISTER_XLA_OP("Relu6", Relu6Op); +REGISTER_XLA_OP("ReluGrad", ReluGradOp); +REGISTER_XLA_OP("Relu6Grad", Relu6GradOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc new file mode 100644 index 0000000000..febce0e126 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc @@ -0,0 +1,101 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA-specific reshape Op. + +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { +namespace { + +class ReshapeOp : public XlaOpKernel { + public: + explicit ReshapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape input_shape = ctx->InputShape(0); + const TensorShape sizes_shape = ctx->InputShape(1); + // Preliminary validation of sizes. + OP_REQUIRES(ctx, IsLegacyVector(sizes_shape), + errors::InvalidArgument("sizes input must be 1-D, not shape ", + sizes_shape.DebugString())); + const int64 num_dims = sizes_shape.num_elements(); + + xla::Literal literal; + OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &literal)); + + // Compute the output shape. Determine product of specified + // dimensions, and find the index of the unspecified one if there + // is one. 
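+    // For example, reshaping a 12-element input with sizes {3, -1} gives a
+    // product of 3 over the specified dimensions, so the missing dimension
+    // is inferred as 12 / 3 = 4.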
+ TensorShape shape; + int64 product = 1; + int unknown_index = -1; + for (int d = 0; d < num_dims; ++d) { + const int32 size = xla::LiteralUtil::Get(literal, {d}); + if (size == -1) { + OP_REQUIRES( + ctx, unknown_index == -1, + errors::InvalidArgument("only one input size may be -1, not both ", + unknown_index, " and ", d)); + unknown_index = d; + shape.AddDim(1); + } else { + OP_REQUIRES(ctx, size >= 0, + errors::InvalidArgument( + "size ", d, " must be non-negative, not ", size)); + shape.AddDim(size); + product *= size; + } + } + if (unknown_index != -1) { + OP_REQUIRES( + ctx, product > 0, + errors::InvalidArgument("Reshape cannot infer the missing input size " + "for an empty tensor unless all specified " + "input sizes are non-zero")); + const int64 missing = input_shape.num_elements() / product; + OP_REQUIRES( + ctx, product * missing == input_shape.num_elements(), + errors::InvalidArgument( + "Input to reshape is a tensor with ", input_shape.num_elements(), + " values, but the requested shape requires a multiple of ", + product)); + shape.set_dim(unknown_index, missing); + } + OP_REQUIRES(ctx, shape.num_elements() == input_shape.num_elements(), + errors::InvalidArgument("Input to reshape is a tensor with ", + input_shape.num_elements(), + " values, but the requested shape has ", + shape.num_elements())); + + VLOG(1) << "Reshape " << input_shape.DebugString() << " " + << shape.DebugString(); + + ctx->SetOutput(0, + ctx->builder()->Reshape(ctx->Input(0), shape.dim_sizes())); + } +}; + +REGISTER_XLA_OP("Reshape", ReshapeOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/retval_op.cc b/tensorflow/compiler/tf2xla/kernels/retval_op.cc new file mode 100644 index 0000000000..87d11a38d4 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/retval_op.cc @@ -0,0 +1,79 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +// This TensorFlow op indicates that its input should be treated as a +// specific return value from a function. +class RetvalOp : public XlaOpKernel { + public: + explicit RetvalOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("index", &index_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + const Tensor& input = ctx->op_kernel_context()->input(0); + + OP_REQUIRES(ctx, input.dtype() == dtype_, + errors::InvalidArgument( + "Type mismatch: actual ", DataTypeString(input.dtype()), + " vs. 
expect ", DataTypeString(dtype_))); + auto frame = ctx->call_frame(); + if (frame) { + // If 'frame' is non-null, this is an inner function call inside a JIT + // compilation. + frame->SetRetval(index_, input); + } else { + xla::ComputationDataHandle input = ctx->Input(0); + const TensorShape input_shape = ctx->InputShape(0); + + auto is_constant = ctx->builder()->IsConstant(input); + if (!is_constant.ok()) { + ctx->SetStatus(is_constant.status()); + return; + } + + XlaContext& tc = XlaContext::Get(ctx); + if (input_shape.num_elements() == 0 || is_constant.ValueOrDie()) { + xla::Literal literal; + OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &literal)); + tc.AddConstRetval(index_, dtype_, literal); + } else { + tc.AddRetval(index_, input); + } + } + } + + private: + // The index of this return value in the returned tuple. + int index_; + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(RetvalOp); +}; + +REGISTER_XLA_OP("_Retval", RetvalOp); + +} // anonymous namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/select_op.cc b/tensorflow/compiler/tf2xla/kernels/select_op.cc new file mode 100644 index 0000000000..0fecc338ca --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/select_op.cc @@ -0,0 +1,90 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/kernels/bounds_check.h" + +namespace tensorflow { +namespace { + +class SelectOp : public XlaOpKernel { + public: + explicit SelectOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape cond_shape = ctx->InputShape(0); + const TensorShape then_shape = ctx->InputShape(1); + const TensorShape else_shape = ctx->InputShape(2); + + OP_REQUIRES( + ctx, then_shape.IsSameSize(else_shape), + errors::InvalidArgument( + "'then' and 'else' must have the same size. but received: ", + then_shape.DebugString(), " vs. ", else_shape.DebugString())); + + bool broadcasting = !cond_shape.IsSameSize(then_shape); + if (broadcasting) { + OP_REQUIRES( + ctx, TensorShapeUtils::IsVector(cond_shape), + errors::InvalidArgument("'cond' must be a vector, but saw shape: ", + cond_shape.DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(then_shape), + errors::InvalidArgument( + "'then' must be at least a vector, but saw shape: ", + then_shape.DebugString())); + OP_REQUIRES(ctx, then_shape.dim_size(0) == cond_shape.num_elements(), + errors::InvalidArgument("Number of batches of 'then' must " + "match size of 'cond', but saw: ", + then_shape.dim_size(0), " vs. 
", + cond_shape.num_elements())); + } + + xla::ComputationBuilder* builder = ctx->builder(); + + auto cond_handle = ctx->Input(0); + auto then_handle = ctx->Input(1); + auto else_handle = ctx->Input(2); + + if (broadcasting) { + // TODO(phawkins): broadcasting on the right seems pretty awkward in + // XLA. It seems we have to broadcast on the left and then Reshape + // to get the dimensions in the right order. + const auto dim_sizes = then_shape.dim_sizes(); + gtl::ArraySlice bdims = dim_sizes; + bdims.pop_front(); + cond_handle = builder->Broadcast(cond_handle, bdims); + + std::vector dim_order(then_shape.dims()); + dim_order[0] = then_shape.dims() - 1; + std::iota(dim_order.begin() + 1, dim_order.end(), 0); + cond_handle = builder->Transpose(cond_handle, dim_order); + } + ctx->SetOutput(0, builder->Select(cond_handle, then_handle, else_handle)); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(SelectOp); +}; + +REGISTER_XLA_OP("Select", SelectOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc new file mode 100644 index 0000000000..42ae978c3c --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc @@ -0,0 +1,213 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA-specific sequence and range Ops. + +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" + +namespace tensorflow { +namespace { + +template +Status GetValue(int index, XlaOpKernelContext* ctx, T* value) { + xla::Literal literal; + TF_RETURN_IF_ERROR(ctx->ConstantInput(index, &literal)); + *value = xla::LiteralUtil::Get(literal, {}); + return Status::OK(); +} + +Status GetIntValue(int index, XlaOpKernelContext* ctx, int64* value) { + xla::Literal literal; + TF_RETURN_IF_ERROR(ctx->ConstantInput(index, &literal)); + switch (literal.shape().element_type()) { + case xla::S32: + *value = xla::LiteralUtil::Get(literal, {}); + break; + case xla::S64: + *value = xla::LiteralUtil::Get(literal, {}); + break; + default: + return errors::InvalidArgument("Invalid argument type for argument", + index); + } + return Status::OK(); +} + +// The type-specific part of the implementation of Range. 
+template +Status CreateRangeTensor(const xla::Literal& start_literal, + const xla::Literal& limit_literal, + const xla::Literal& delta_literal, Tensor* output) { + T start = xla::LiteralUtil::Get(start_literal, {}); + T limit = xla::LiteralUtil::Get(limit_literal, {}); + T delta = xla::LiteralUtil::Get(delta_literal, {}); + + if (delta == 0) { + return errors::InvalidArgument("Requires delta != 0: ", delta); + } + if (delta > 0) { + if (start > limit) { + return errors::InvalidArgument("Requires start <= limit when delta > 0: ", + start, "/", limit); + } + } else { + if (start < limit) { + return errors::InvalidArgument("Requires start >= limit when delta < 0: ", + start, "/", limit); + } + } + int64 size = + (std::is_integral::value + ? ((std::abs(limit - start) + std::abs(delta) - 1) / std::abs(delta)) + : std::ceil(std::abs((limit - start) / delta))); + + *output = Tensor(DataTypeToEnum::v(), TensorShape({size})); + auto flat = output->flat(); + T val = start; + for (int64 i = 0; i < size; ++i) { + flat(i) = val; + val += delta; + } + return Status::OK(); +} + +class RangeOp : public XlaOpKernel { + public: + explicit RangeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape start_in_shape = ctx->InputShape(0); + const TensorShape limit_in_shape = ctx->InputShape(1); + const TensorShape delta_in_shape = ctx->InputShape(2); + OP_REQUIRES(ctx, IsLegacyScalar(start_in_shape), + errors::InvalidArgument("start must be a scalar, not shape ", + start_in_shape.DebugString())); + OP_REQUIRES(ctx, IsLegacyScalar(limit_in_shape), + errors::InvalidArgument("limit must be a scalar, not shape ", + limit_in_shape.DebugString())); + OP_REQUIRES(ctx, IsLegacyScalar(delta_in_shape), + errors::InvalidArgument("delta must be a scalar, not shape ", + delta_in_shape.DebugString())); + xla::Literal start, limit, delta; + OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &start)); + OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &limit)); + OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &delta)); + + DataType type = input_type(0); + Tensor output; + Status status; + switch (type) { + case DT_INT32: + status = CreateRangeTensor(start, limit, delta, &output); + break; + case DT_INT64: + status = CreateRangeTensor(start, limit, delta, &output); + break; + case DT_FLOAT: + status = CreateRangeTensor(start, limit, delta, &output); + break; + case DT_DOUBLE: + status = CreateRangeTensor(start, limit, delta, &output); + break; + default: + status = errors::InvalidArgument("Invalid type for Range ", + DataTypeString(type)); + } + OP_REQUIRES_OK(ctx, status); + ctx->SetConstantOutput(0, output); + } +}; + +REGISTER_XLA_OP("Range", RangeOp); + +class LinSpaceOp : public XlaOpKernel { + public: + explicit LinSpaceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape start_in_shape = ctx->InputShape(0); + const TensorShape stop_in_shape = ctx->InputShape(1); + const TensorShape num_in_shape = ctx->InputShape(2); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(start_in_shape), + errors::InvalidArgument("start must be a scalar, not shape ", + start_in_shape.DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(stop_in_shape), + errors::InvalidArgument("stop must be a scalar, not shape ", + stop_in_shape.DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(num_in_shape), + errors::InvalidArgument("num must be a scalar, not shape ", + num_in_shape.DebugString())); + + DataType type = 
ctx->input_type(0); + + int64 num; + OP_REQUIRES_OK(ctx, GetIntValue(2, ctx, &num)); + OP_REQUIRES(ctx, num > 0, + errors::InvalidArgument("Requires num > 0: ", num)); + Tensor out_constant(type, TensorShape({num})); + + switch (type) { + case DT_FLOAT: { + float start, stop; + OP_REQUIRES_OK(ctx, GetValue(0, ctx, &start)); + OP_REQUIRES_OK(ctx, GetValue(1, ctx, &stop)); + auto flat = out_constant.flat(); + if (num == 1) { + flat(0) = start; + } else { + const float step = (stop - start) / (num - 1); + for (int64 i = 0; i < num; ++i) { + flat(i) = start + step * i; + } + } + break; + } + case DT_DOUBLE: { + double start, stop; + OP_REQUIRES_OK(ctx, GetValue(0, ctx, &start)); + OP_REQUIRES_OK(ctx, GetValue(1, ctx, &stop)); + auto flat = out_constant.flat(); + if (num == 1) { + flat(0) = start; + } else { + const double step = (stop - start) / (num - 1); + for (int64 i = 0; i < num; ++i) { + flat(i) = start + step * i; + } + } + break; + } + + default: + ctx->SetStatus(errors::InvalidArgument("Invalid argument type ", + DataTypeString(type))); + return; + } + ctx->SetConstantOutput(0, out_constant); + } +}; + +REGISTER_XLA_OP("LinSpace", LinSpaceOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc new file mode 100644 index 0000000000..e7eec1cefd --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -0,0 +1,245 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA-specific Shape Ops. + +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/kernels/bounds_check.h" + +namespace tensorflow { +namespace { + +class ShapeOp : public XlaOpKernel { + public: + explicit ShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape input_shape = ctx->InputShape(0); + const int rank = input_shape.dims(); + Tensor shape_constant(DT_INT32, TensorShape({rank})); + auto vec = shape_constant.vec(); + // TODO(dga): support int64. b/28119922. 
+ for (int i = 0; i < rank; ++i) { + int64 dim_size = input_shape.dim_size(i); + OP_REQUIRES( + ctx, FastBoundsCheck(dim_size, std::numeric_limits::max()), + errors::InvalidArgument("Shape does not support tensors > int32max", + " but dim ", i, " is ", dim_size)); + vec(i) = static_cast(dim_size); + } + + ctx->SetConstantOutput(0, shape_constant); + } +}; + +REGISTER_XLA_OP("Shape", ShapeOp); + +class ShapeNOp : public XlaOpKernel { + public: + explicit ShapeNOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + for (int i = 0; i < ctx->num_inputs(); ++i) { + const TensorShape shape = ctx->InputShape(i); + const int dims = shape.dims(); + Tensor shape_constant(DT_INT32, TensorShape({dims})); + auto vec = shape_constant.vec(); + + // TODO(dga): support int64. b/28119922. + for (int j = 0; j < dims; ++j) { + int64 dim_size = shape.dim_size(j); + OP_REQUIRES( + ctx, FastBoundsCheck(dim_size, std::numeric_limits::max()), + errors::InvalidArgument("Shape does not support tensors > int32max", + " but shape ", i, " dim ", j, " is ", + dim_size)); + vec(j) = static_cast(dim_size); + } + + ctx->SetConstantOutput(i, shape_constant); + } + } + + bool IsExpensive() override { return false; } +}; +REGISTER_XLA_OP("ShapeN", ShapeNOp); + +class RankOp : public XlaOpKernel { + public: + explicit RankOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape input_shape = ctx->InputShape(0); + const int rank = input_shape.dims(); + Tensor rank_constant(DT_INT32, TensorShape({})); + rank_constant.scalar()() = rank; + + ctx->SetConstantOutput(0, rank_constant); + } +}; + +REGISTER_XLA_OP("Rank", RankOp); + +class SizeOp : public XlaOpKernel { + public: + explicit SizeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape input_shape = ctx->InputShape(0); + const int64 size = input_shape.num_elements(); + OP_REQUIRES(ctx, FastBoundsCheck(size, std::numeric_limits::max()), + errors::InvalidArgument("Size does not work for tensors > " + "int32 max.")); + Tensor size_constant(DT_INT32, TensorShape({})); + size_constant.scalar()() = static_cast(size); + + ctx->SetConstantOutput(0, size_constant); + } +}; + +REGISTER_XLA_OP("Size", SizeOp); + +class ExpandDimsOp : public XlaOpKernel { + public: + explicit ExpandDimsOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape input_shape = ctx->InputShape(0); + const TensorShape dim_shape = ctx->InputShape(1); + + // TODO(phawkins): the standard implementation of ExpandDimsOp seems to + // accept legacy scalars, even when they should be forbidden by the graphdef + // version. + OP_REQUIRES(ctx, dim_shape.num_elements() == 1, + errors::InvalidArgument(strings::StrCat( + "dim input to ExpandDims must be a scalar; got ", + dim_shape.DebugString()))); + + xla::Literal literal; + OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {1}, &literal)); + + int dim = literal.s32s(0); + + OP_REQUIRES(ctx, + (dim >= -1 - input_shape.dims() && dim <= input_shape.dims()), + errors::InvalidArgument("Tried to expand dim index ", dim, + " for tensor with ", input_shape.dims(), + " dimensions.")); + + auto existing_dims = input_shape.dim_sizes(); + // Safe - # elements in tensor dims bounded. 
+ const int existing_dims_size = static_cast(existing_dims.size()); + std::vector new_shape(existing_dims_size); + for (size_t i = 0; i < new_shape.size(); ++i) { + new_shape[i] = existing_dims[i]; + } + + // We emulate numpy's interpretation of the dim axis when + // -input.dims() >= dim <= input.dims(). + if (dim < 0) { + dim += existing_dims.size() + 1; + } + + // Clamp to the end if needed. + dim = std::min(dim, existing_dims_size); + new_shape.emplace(new_shape.begin() + dim, 1); + + ctx->SetOutput(0, ctx->builder()->Reshape(ctx->Input(0), new_shape)); + } +}; +REGISTER_XLA_OP("ExpandDims", ExpandDimsOp); + +class SqueezeOp : public XlaOpKernel { + public: + explicit SqueezeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + std::vector squeeze_dims; + OP_REQUIRES_OK(ctx, ctx->GetAttr("squeeze_dims", &squeeze_dims)); + squeeze_dims_.insert(squeeze_dims.begin(), squeeze_dims.end()); + } + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape input_shape = ctx->InputShape(0); + auto existing_dims = input_shape.dim_sizes(); + int existing_dims_size = input_shape.dims(); + std::vector new_shape; + + std::unordered_set wrapped_squeeze_dims; + wrapped_squeeze_dims.reserve(squeeze_dims_.size()); + // Validate squeeze dims against the input. + for (int32 dim : squeeze_dims_) { + OP_REQUIRES(ctx, (dim >= -input_shape.dims() && dim < input_shape.dims()), + errors::InvalidArgument("Tried to squeeze dim index ", dim, + " for tensor with ", + input_shape.dims(), " dimensions.")); + // If dim is < 0, we wrap around (-1 means the last element). + if (dim < 0) { + dim = existing_dims_size + dim; + } + + wrapped_squeeze_dims.insert(dim); + } + + for (int i = 0; i < existing_dims_size; ++i) { + auto existing_dim = existing_dims[i]; + + // If squeeze_set is non-empty, only squeeze those dimensions. + if (!wrapped_squeeze_dims.empty()) { + if (wrapped_squeeze_dims.count(i) > 0) { + OP_REQUIRES(ctx, existing_dim == 1, + errors::InvalidArgument("Tried to explicitly squeeze " + "dimension ", + i, " but dimension was not 1: ", + existing_dim)); + } else { + // This dimension is not being squeezed. + new_shape.push_back(existing_dim); + } + } else { + // Copy over all non-1-length dimensions. + if (existing_dim != 1) { + new_shape.push_back(existing_dim); + } + } + } + + ctx->SetOutput(0, ctx->builder()->Reshape(ctx->Input(0), new_shape)); + } + + private: + std::unordered_set squeeze_dims_; +}; + +REGISTER_XLA_OP("Squeeze", SqueezeOp); + +class ZerosLikeOp : public XlaOpKernel { + public: + explicit ZerosLikeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape input_shape = ctx->InputShape(0); + + auto zero = XlaHelpers::Zero(ctx->builder(), input_type(0)); + ctx->SetOutput(0, ctx->builder()->Broadcast(zero, input_shape.dim_sizes())); + } +}; + +REGISTER_XLA_OP("ZerosLike", ZerosLikeOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/slice_op.cc b/tensorflow/compiler/tf2xla/kernels/slice_op.cc new file mode 100644 index 0000000000..8ec77e04af --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc @@ -0,0 +1,121 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA-specific Slice Op. + +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/mem.h" + +namespace tensorflow { +namespace { + +class SliceOp : public XlaOpKernel { + public: + explicit SliceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + bool is_identity = true; + std::vector begin; + std::vector size; + SharedValidation(ctx, &is_identity, &begin, &size); + if (!ctx->status().ok()) return; + + if (is_identity) { + VLOG(1) << "Slice identity"; + ctx->SetOutput(0, ctx->Input(0)); + return; + } + + // slice will be an empty handle if the output has no elements. + CHECK_EQ(begin.size(), size.size()); + std::vector limits; + for (int i = 0; i < begin.size(); ++i) { + limits.push_back(begin[i] + size[i]); + } + ctx->SetOutput(0, ctx->builder()->Slice(ctx->Input(0), begin, limits)); + } + + private: + void SharedValidation(XlaOpKernelContext* ctx, bool* is_identity, + std::vector* begin, std::vector* size); +}; + +void SliceOp::SharedValidation(XlaOpKernelContext* ctx, bool* is_identity, + std::vector* begin, + std::vector* size) { + const TensorShape input_shape = ctx->InputShape(0); + const TensorShape begin_tensor_shape = ctx->InputShape(1); + const TensorShape size_tensor_shape = ctx->InputShape(2); + + OP_REQUIRES( + ctx, + IsLegacyVector(begin_tensor_shape) && IsLegacyVector(size_tensor_shape) && + begin_tensor_shape.num_elements() == input_shape.dims() && + size_tensor_shape.num_elements() == input_shape.dims(), + errors::InvalidArgument( + "Expected begin and size arguments to be 1-D tensors of size ", + input_shape.dims(), ", but got shapes ", + begin_tensor_shape.DebugString(), " and ", + size_tensor_shape.DebugString(), " instead.")); + + const int input_dims = input_shape.dims(); + + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, begin)); + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, size)); + for (int i = 0; i < input_dims; ++i) { + if ((*size)[i] == -1) { + // A size[i] of -1 means "all elements from begin[i] to dim_size(i)". 
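SharedValidation continues below by resolving size[i] == -1 to "everything from begin[i] to the end of the dimension", bounds-checking begin and size, and turning (begin, size) into the exclusive limits handed to Slice. A standalone sketch of that resolution, with illustrative names and no TensorFlow dependencies:

#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Resolves a (begin, size) slice spec against dims, mirroring the rules in
// SliceOp::SharedValidation: size == -1 means "to the end of the dimension",
// and the resulting window must stay inside the input.
std::vector<int64_t> ResolveSliceLimits(const std::vector<int64_t>& dims,
                                        const std::vector<int64_t>& begin,
                                        std::vector<int64_t> size) {
  std::vector<int64_t> limits(dims.size());
  for (size_t i = 0; i < dims.size(); ++i) {
    if (size[i] == -1) size[i] = dims[i] - begin[i];
    if (begin[i] < 0 || begin[i] > dims[i] || size[i] < 0 ||
        begin[i] + size[i] > dims[i]) {
      throw std::invalid_argument("slice out of bounds");
    }
    limits[i] = begin[i] + size[i];  // exclusive upper bound passed to Slice
  }
  return limits;
}

int main() {
  // A [2, 10] input sliced with begin = {0, 3}, size = {-1, 4}.
  const auto limits = ResolveSliceLimits({2, 10}, {0, 3}, {-1, 4});
  assert(limits[0] == 2 && limits[1] == 7);
  return 0;
}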
+ (*size)[i] = input_shape.dim_size(i) - (*begin)[i]; + } + } + + *is_identity = true; + for (int i = 0; i < input_dims; ++i) { + int64 b = (*begin)[i]; + int64 s = (*size)[i]; + if (input_shape.dim_size(i) == 0) { + OP_REQUIRES(ctx, b == 0 && s == 0, + errors::InvalidArgument( + "Expected begin[", i, "] == 0 (got ", b, ") and size[", i, + "] == 0 ", "(got ", s, ") when ", "input_shape.dim_size(", + i, ") == 0")); + } else { + OP_REQUIRES( + ctx, 0 <= b && b <= input_shape.dim_size(i), + errors::InvalidArgument("Expected begin[", i, "] in [0, ", + input_shape.dim_size(i), "], but got ", b)); + OP_REQUIRES(ctx, 0 <= s && b + s <= input_shape.dim_size(i), + errors::InvalidArgument("Expected size[", i, "] in [0, ", + input_shape.dim_size(i) - b, + "], but ", "got ", s)); + } + const bool take_all = (b == 0) && (s == input_shape.dim_size(i)); + (*is_identity) &= take_all; + } +} + +REGISTER_XLA_OP("Slice", SliceOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc new file mode 100644 index 0000000000..06ee520163 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc @@ -0,0 +1,152 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA-specific Ops for softmax. + +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" + +namespace tensorflow { +namespace { + +class SoftmaxOp : public XlaOpKernel { + public: + explicit SoftmaxOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + log_ = StringPiece(type_string()).starts_with("Log"); + } + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape logits_shape = ctx->InputShape(0); + OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_shape), + errors::InvalidArgument("logits must be 2-dimensional")); + + const int kBatchDim = 0; + const int kClassDim = 1; + + const DataType type = input_type(0); + auto logits = ctx->Input(0); + + xla::ComputationBuilder* b = ctx->builder(); + const xla::Computation& max_func = *ctx->GetOrCreateMax(type); + const xla::Computation& add_func = *ctx->GetOrCreateAdd(type); + + // Find the max in each batch, resulting in a tensor of shape [batch] + auto logits_max = + b->Reduce(logits, XlaHelpers::MinValue(b, type), max_func, {kClassDim}); + // Subtract the max in batch b from every element in batch b. Broadcasts + // along the batch dimension. 
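Subtracting the per-row maximum is the usual numerical-stability trick: softmax(x) equals softmax(x - max(x)), and the shifted form cannot overflow exp. A single-row sketch of the same arithmetic in plain C++ (the function name is illustrative):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Computes softmax of one row of logits using shifted logits, matching the
// math the kernel expresses with Reduce/Sub/Exp/Div over the class dimension.
std::vector<double> Softmax(const std::vector<double>& logits) {
  const double max_logit = *std::max_element(logits.begin(), logits.end());
  std::vector<double> out(logits.size());
  double sum_exp = 0.0;
  for (size_t i = 0; i < logits.size(); ++i) {
    out[i] = std::exp(logits[i] - max_logit);  // exp of shifted logits
    sum_exp += out[i];
  }
  for (double& v : out) v /= sum_exp;
  return out;
}

int main() {
  // Large logits would overflow exp() without the shift; with it they do not.
  for (double p : Softmax({1000.0, 1001.0, 1002.0})) std::printf("%f\n", p);
  return 0;
}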
+ auto shifted_logits = b->Sub(logits, logits_max, {kBatchDim}); + xla::ComputationDataHandle softmax; + if (log_) { + // softmax = shifted_logits - log(sum(exp(shifted_logits))) + auto log_sum_exp = + b->Log(b->Reduce(b->Exp(shifted_logits), XlaHelpers::Zero(b, type), + add_func, {kClassDim})); + softmax = b->Sub(shifted_logits, log_sum_exp, {kBatchDim}); + } else { + // softmax = exp(shifted_logits) / sum(exp(shifted_logits)) + auto exp_shifted = b->Exp(shifted_logits); + auto sum_exp = b->Reduce(exp_shifted, XlaHelpers::Zero(b, type), add_func, + {kClassDim}); + softmax = b->Div(exp_shifted, sum_exp, {kBatchDim}); + } + + ctx->SetOutput(0, softmax); + } + + private: + bool log_; +}; + +REGISTER_XLA_OP("Softmax", SoftmaxOp); +REGISTER_XLA_OP("LogSoftmax", SoftmaxOp); + +class SoftmaxXentWithLogitsOp : public XlaOpKernel { + public: + explicit SoftmaxXentWithLogitsOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape logits_shape = ctx->InputShape(0); + const TensorShape labels_shape = ctx->InputShape(1); + OP_REQUIRES(ctx, logits_shape.IsSameSize(labels_shape), + errors::InvalidArgument( + "logits and labels must be same size: logits_size=", + logits_shape.DebugString(), " labels_size=", + labels_shape.DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_shape), + errors::InvalidArgument("logits must be 2-dimensional")); + // As we already tested that both inputs have the same shape no need to + // check that "labels" is a matrix too. + + // loss is 1-D (one per example), and size is batch_size. + + const int kBatchDim = 0; + const int kClassDim = 1; + + const DataType type = input_type(0); + xla::ComputationBuilder* b = ctx->builder(); + auto logits = ctx->Input(0); + auto labels = ctx->Input(1); + + const xla::Computation& max_func = *ctx->GetOrCreateMax(type); + const xla::Computation& add_func = *ctx->GetOrCreateAdd(type); + + // Find the max in each batch, resulting in a tensor of shape [batch] + auto logits_max = + b->Reduce(logits, XlaHelpers::MinValue(b, type), max_func, {kClassDim}); + + // Subtract the max in batch b from every element in batch b. + // Broadcasts along the batch dimension. + auto shifted_logits = b->Sub(logits, logits_max, {kBatchDim}); + + // exp(logits - max_logits) + auto exp_shifted_logits = b->Exp(shifted_logits); + + // sum_{class} (exp(logits - max_logits)) + auto sum_exp = b->Reduce(exp_shifted_logits, XlaHelpers::Zero(b, type), + add_func, {kClassDim}); + + // log(sum(exp(logits - max_logits))) + auto log_sum_exp = b->Log(sum_exp); + + // sum(-labels * + // ((logits - max_logits) - log(sum(exp(logits - max_logits))))) + // along classes + // (The subtraction broadcasts along the batch dimension.) 
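The loss and backprop built below follow the standard identities loss = sum_c(-labels_c * (shifted_c - log(sum_c exp(shifted_c)))) and d loss / d logits = softmax(logits) - labels. A per-example sketch of both quantities, assuming a single row of logits and labels (names illustrative):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// One row of softmax cross-entropy: returns the scalar loss and fills
// backprop with (softmax(logits) - labels), mirroring the kernel's formulas.
double XentWithLogits(const std::vector<double>& logits,
                      const std::vector<double>& labels,
                      std::vector<double>* backprop) {
  const double max_logit = *std::max_element(logits.begin(), logits.end());
  double sum_exp = 0.0;
  for (double l : logits) sum_exp += std::exp(l - max_logit);
  const double log_sum_exp = std::log(sum_exp);

  double loss = 0.0;
  backprop->resize(logits.size());
  for (size_t c = 0; c < logits.size(); ++c) {
    const double shifted = logits[c] - max_logit;
    loss += -labels[c] * (shifted - log_sum_exp);
    (*backprop)[c] = std::exp(shifted) / sum_exp - labels[c];
  }
  return loss;
}

int main() {
  std::vector<double> backprop;
  const double loss =
      XentWithLogits({2.0, 1.0, 0.1}, {1.0, 0.0, 0.0}, &backprop);
  std::printf("loss=%f dlogit0=%f\n", loss, backprop[0]);
  return 0;
}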
+ xla::ComputationDataHandle loss = + b->Reduce(b->Mul(b->Neg(labels), + b->Sub(shifted_logits, log_sum_exp, {kBatchDim})), + XlaHelpers::Zero(b, type), add_func, {kClassDim}); + + // backprop: prob - labels, where + // prob = exp(logits - max_logits) / sum(exp(logits - max_logits)) + // (where the division broadcasts along the batch dimension) + xla::ComputationDataHandle backprop = + b->Sub(b->Div(exp_shifted_logits, sum_exp, {kBatchDim}), labels); + + ctx->SetOutput(0, loss); + ctx->SetOutput(1, backprop); + } +}; + +REGISTER_XLA_OP("SoftmaxCrossEntropyWithLogits", SoftmaxXentWithLogitsOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc new file mode 100644 index 0000000000..18c4c648db --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc @@ -0,0 +1,208 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA-specific Ops for split. + +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" + +namespace tensorflow { +namespace { + +class SplitOp : public XlaOpKernel { + public: + explicit SplitOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape index_shape = ctx->InputShape(0); + xla::Literal literal_index; + OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &literal_index)); + + int32 split_dim; + if (index_shape.dims() == 0) { + split_dim = xla::LiteralUtil::Get(literal_index, {}); + } else { + OP_REQUIRES( + ctx, index_shape.dims() == 1, + errors::InvalidArgument("split_index input to Split Op must be a " + "scalar or a vector with 1 element")); + OP_REQUIRES( + ctx, index_shape.dim_size(0) == 1, + errors::InvalidArgument("split_index input to Split Op must be a " + "scalar or a vector with 1 element")); + split_dim = xla::LiteralUtil::Get(literal_index, {0}); + } + const int32 num_split = num_outputs(); + const TensorShape input_shape = ctx->InputShape(1); + + OP_REQUIRES( + ctx, 0 <= split_dim && split_dim < input_shape.dims(), + errors::InvalidArgument("0 <= split_dim < number of input dimensions (", + input_shape.dims(), "), but got ", split_dim)); + + OP_REQUIRES( + ctx, num_split > 0, + errors::InvalidArgument( + "Number of ways to split should be > 0, but got ", num_split)); + + OP_REQUIRES(ctx, input_shape.dim_size(split_dim) % num_split == 0, + errors::InvalidArgument( + "Number of ways to split should evenly divide the split " + "dimension, but got 
split_dim ", + split_dim, " (size = ", input_shape.dim_size(split_dim), + ") ", "and num_split ", num_split)); + + // All the slices are the same size: this is the size along the + // split dimension. + const int32 slice_size = input_shape.dim_size(split_dim) / num_split; + + // The vectors we will use to define the slice. The entry for the + // split dimensions varies for each output. + std::vector begin; + std::vector limits; + for (int i = 0; i < input_shape.dims(); ++i) { + // Initially set up the limits to be the full size of the input: + // the split dimension is filled in below. + int64 dim = input_shape.dim_size(i); + begin.push_back(0); + limits.push_back(dim); + } + + auto input = ctx->Input(1); + + // Create each of the outputs. + for (int i = 0; i < num_split; ++i) { + // Slice out the ith split from the split dimension. + begin[split_dim] = i * slice_size; + limits[split_dim] = (i + 1) * slice_size; + ctx->SetOutput(i, ctx->builder()->Slice(input, begin, limits)); + } + } +}; + +REGISTER_XLA_OP("Split", SplitOp); + +class SplitVOp : public XlaOpKernel { + public: + explicit SplitVOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + const int32 num_split = num_outputs(); + const TensorShape index_shape = ctx->InputShape(2); + xla::Literal literal_index; + OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &literal_index)); + + int32 split_dim; + OP_REQUIRES(ctx, index_shape.dims() == 0, + errors::InvalidArgument("split_dim input to Split Op must be a " + "scalar")); + split_dim = xla::LiteralUtil::Get(literal_index, {}); + + xla::ComputationDataHandle input = ctx->Input(0); + const TensorShape input_shape = ctx->InputShape(0); + + OP_REQUIRES(ctx, input_shape.dims() > 0, + errors::InvalidArgument("Can't split a 0 dimensional input")); + + OP_REQUIRES( + ctx, 0 <= split_dim && split_dim < input_shape.dims(), + errors::InvalidArgument("0 <= split_dim < number of input dimensions (", + input_shape.dims(), "), but got ", split_dim)); + + OP_REQUIRES( + ctx, num_split > 0, + errors::InvalidArgument( + "Number of ways to split should be > 0, but got ", num_split)); + + // check that sizes are correct + int total_split_size = 0; + int neg_one_dim = -1; + std::vector split_sizes_vec(num_split, -1); + const TensorShape split_size_shape = ctx->InputShape(1); + OP_REQUIRES(ctx, split_size_shape.dims() == 1 && + split_size_shape.num_elements() == num_split, + errors::InvalidArgument( + "shape of tensor describing " + " the output must have dimension 1 and the same " + " number of elements as the output. Got ", + split_size_shape.dims(), "-D and ", + split_size_shape.num_elements(), " elements")); + // get the dimension of this split + xla::Literal split_size_literal; + OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &split_size_literal)); + + for (int i = 0; i < num_split; ++i) { + int slice_size; + slice_size = xla::LiteralUtil::Get(split_size_literal, {i}); + if (slice_size == -1) { + OP_REQUIRES( + ctx, neg_one_dim == -1, + errors::InvalidArgument("Only one dimensions can have a value of" + "-1. 
Second one found at dimension ", + i)); + neg_one_dim = i; + } else { + split_sizes_vec[i] = slice_size; + total_split_size += slice_size; + } + } + + OP_REQUIRES( + ctx, (neg_one_dim == -1 && + total_split_size == input_shape.dim_size(split_dim)) || + (neg_one_dim >= 0 && + total_split_size <= input_shape.dim_size(split_dim)), + errors::InvalidArgument("Determined shape must either match " + "input shape along split_dim exactly if " + "fully specified, or be less than the size of " + "the input along split_dim if not fully " + "specified. Got: ", + total_split_size)); + + if (neg_one_dim >= 0) { + split_sizes_vec[neg_one_dim] = + input_shape.dim_size(split_dim) - total_split_size; + } + + // The vectors we will use to define the slice. The entry for the + // split dimensions varies for each output. + std::vector begin(input_shape.dims(), 0); + auto dim_sizes = input_shape.dim_sizes(); + std::vector limits(dim_sizes.begin(), dim_sizes.end()); + + for (int i = 0; i < num_split; ++i) { + TensorShape output_shape(input_shape); + int slice_size = split_sizes_vec[i]; + output_shape.set_dim(split_dim, slice_size); + + // Slice out the ith split from the split dimension. + limits[split_dim] = begin[split_dim] + slice_size; + ctx->SetOutput(i, ctx->builder()->Slice(input, begin, limits)); + begin[split_dim] = limits[split_dim]; + } + } +}; + +REGISTER_XLA_OP("SplitV", SplitVOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc new file mode 100644 index 0000000000..83bf24814f --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -0,0 +1,223 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/util/strided_slice_op.h" +#include "tensorflow/compiler/tf2xla/literal_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/mem.h" + +namespace tensorflow { +namespace { + +class StridedSliceOp : public XlaOpKernel { + public: + explicit StridedSliceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("begin_mask", &begin_mask_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("end_mask", &end_mask_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("ellipsis_mask", &ellipsis_mask_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("new_axis_mask", &new_axis_mask_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("shrink_axis_mask", &shrink_axis_mask_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("Index", &index_type_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape input_shape = ctx->InputShape(0); + + TensorShape final_shape; + gtl::InlinedVector begin; + gtl::InlinedVector end; + gtl::InlinedVector strides; + + xla::Literal begin_literal, end_literal, strides_literal; + OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &begin_literal)); + OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &end_literal)); + OP_REQUIRES_OK(ctx, ctx->ConstantInput(3, &strides_literal)); + + Tensor begin_tensor, end_tensor, strides_tensor; + OP_REQUIRES_OK( + ctx, LiteralToHostTensor(begin_literal, index_type_, &begin_tensor)); + OP_REQUIRES_OK(ctx, + LiteralToHostTensor(end_literal, index_type_, &end_tensor)); + OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_, + &strides_tensor)); + + TensorShape dummy_processing_shape; + ShapeReadWriteFromTensorShape wrapped_final_shape(&final_shape); + ShapeReadWriteFromTensorShape wrapped_dummy_processing_shape( + &dummy_processing_shape); + bool dummy = false; + OP_REQUIRES_OK( + ctx, ValidateStridedSliceOp( + &begin_tensor, &end_tensor, strides_tensor, + ShapeReadWriteFromTensorShape(&input_shape), begin_mask_, + end_mask_, ellipsis_mask_, new_axis_mask_, shrink_axis_mask_, + &wrapped_dummy_processing_shape, &wrapped_final_shape, &dummy, + &dummy, &dummy, &begin, &end, &strides)); + + gtl::InlinedVector dimensions_to_reverse; + gtl::InlinedVector slice_begin, slice_end; + for (int i = 0; i < begin.size(); ++i) { + // TODO(phawkins): implement strides != 1 when b/30878775 is fixed. + OP_REQUIRES( + ctx, strides[i] == 1 || strides[i] == -1, + errors::Unimplemented("Strides != 1 or -1 are not yet implemented")); + if (strides[i] > 0) { + slice_begin.push_back(begin[i]); + slice_end.push_back(end[i]); + } else { + // Negative stride: swap begin and end, add 1 because the interval + // is semi-open, and mark the dimension to be reversed. 
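Concretely, for stride -1 the half-open interval [begin, end) stepping downward becomes the forward interval [end + 1, begin + 1), followed by a reversal of that dimension. The same rewrite on a 1-D std::vector, as a sanity check of the index arithmetic (the helper name is illustrative):

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Emulates x[begin:end:-1] (NumPy-style, stride -1) by taking the forward
// slice [end + 1, begin + 1) and reversing it, as the kernel does with a
// Slice followed by Rev.
std::vector<int> ReverseStrideSlice(const std::vector<int>& x,
                                    int64_t begin, int64_t end) {
  std::vector<int> out(x.begin() + (end + 1), x.begin() + (begin + 1));
  std::reverse(out.begin(), out.end());
  return out;
}

int main() {
  const std::vector<int> x = {0, 1, 2, 3, 4, 5};
  // x[4:1:-1] -> {4, 3, 2}; the forward slice [2, 5) is {2, 3, 4}, reversed.
  const std::vector<int> expected = {4, 3, 2};
  assert(ReverseStrideSlice(x, /*begin=*/4, /*end=*/1) == expected);
  return 0;
}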
+ slice_begin.push_back(end[i] + 1); + slice_end.push_back(begin[i] + 1); + dimensions_to_reverse.push_back(i); + } + } + xla::ComputationDataHandle slice = + ctx->builder()->Slice(ctx->Input(0), slice_begin, slice_end); + if (!dimensions_to_reverse.empty()) { + slice = ctx->builder()->Rev(slice, dimensions_to_reverse); + } + + slice = ctx->builder()->Reshape(slice, final_shape.dim_sizes()); + ctx->SetOutput(0, slice); + } + + private: + int32 begin_mask_, end_mask_; + int32 ellipsis_mask_, new_axis_mask_, shrink_axis_mask_; + DataType index_type_; +}; + +REGISTER_XLA_OP("StridedSlice", StridedSliceOp); + +class StridedSliceGradOp : public XlaOpKernel { + public: + explicit StridedSliceGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("begin_mask", &begin_mask_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("end_mask", &end_mask_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("ellipsis_mask", &ellipsis_mask_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("new_axis_mask", &new_axis_mask_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("shrink_axis_mask", &shrink_axis_mask_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("Index", &index_type_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape processing_shape, final_shape; + gtl::InlinedVector begin; + gtl::InlinedVector end; + gtl::InlinedVector strides; + + TensorShape input_shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_shape)); + + xla::Literal begin_literal, end_literal, strides_literal; + OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &begin_literal)); + OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &end_literal)); + OP_REQUIRES_OK(ctx, ctx->ConstantInput(3, &strides_literal)); + + Tensor begin_tensor, end_tensor, strides_tensor; + OP_REQUIRES_OK( + ctx, LiteralToHostTensor(begin_literal, index_type_, &begin_tensor)); + OP_REQUIRES_OK(ctx, + LiteralToHostTensor(end_literal, index_type_, &end_tensor)); + OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_, + &strides_tensor)); + + bool dummy = false; + ShapeReadWriteFromTensorShape wrapped_final_shape(&final_shape); + ShapeReadWriteFromTensorShape wrapped_processing_shape(&processing_shape); + OP_REQUIRES_OK( + ctx, ValidateStridedSliceOp( + &begin_tensor, &end_tensor, strides_tensor, + ShapeReadWriteFromTensorShape(&input_shape), begin_mask_, + end_mask_, ellipsis_mask_, new_axis_mask_, shrink_axis_mask_, + &wrapped_processing_shape, &wrapped_final_shape, &dummy, + &dummy, &dummy, &begin, &end, &strides)); + + // Check to make sure dy is consistent with the original slice + const TensorShape dy_shape = ctx->InputShape(4); + OP_REQUIRES( + ctx, final_shape == dy_shape, + errors::InvalidArgument("shape of dy was ", dy_shape.DebugString(), + " instead of ", final_shape.DebugString())); + + OP_REQUIRES( + ctx, input_shape.dims() == processing_shape.dims(), + errors::Internal( + "input shape and processing shape must have same number of dims")); + + auto zero = XlaHelpers::Zero(ctx->builder(), ctx->expected_output_dtype(0)); + + xla::ComputationDataHandle grad = ctx->Input(4); + + // Undo any new/shrink axes. + grad = ctx->builder()->Reshape(grad, processing_shape.dim_sizes()); + + // Pad the input gradients. 
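The padding computed in the loop below rebuilds a tensor of the original input shape from the sliced gradient: for a positive stride, edge_padding_low is begin, interior_padding is stride - 1, and edge_padding_high is whatever remains, since end alone is not enough when the stride does not divide the interval evenly. A 1-D sketch of that bookkeeping for the positive-stride case only (names illustrative):

#include <cassert>
#include <cstdint>
#include <vector>

// Scatters a 1-D gradient slice back into a zero vector of length input_dim,
// using the same low/interior/high padding arithmetic as StridedSliceGradOp
// (positive strides only).
std::vector<double> PadGrad1D(const std::vector<double>& grad,
                              int64_t input_dim, int64_t begin,
                              int64_t stride) {
  const int64_t low = begin;
  const int64_t interior = stride - 1;
  // Span covered so far: low padding + elements + interior gaps between them.
  const int64_t covered = low + static_cast<int64_t>(grad.size()) +
                          (static_cast<int64_t>(grad.size()) - 1) * interior;
  const int64_t high = input_dim - covered;
  assert(high >= 0);

  std::vector<double> out(input_dim, 0.0);
  for (size_t i = 0; i < grad.size(); ++i) {
    out[low + static_cast<int64_t>(i) * stride] = grad[i];
  }
  return out;
}

int main() {
  // Forward: x[1:8:3] over a length-10 input picks indices {1, 4, 7}.
  const std::vector<double> grad = {10.0, 20.0, 30.0};
  const std::vector<double> out =
      PadGrad1D(grad, /*input_dim=*/10, /*begin=*/1, /*stride=*/3);
  assert(out[1] == 10.0 && out[4] == 20.0 && out[7] == 30.0 && out[9] == 0.0);
  return 0;
}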
+ gtl::InlinedVector dimensions_to_reverse; + xla::PaddingConfig padding_config; + + for (int i = 0; i < processing_shape.dims(); ++i) { + auto* dims = padding_config.add_dimensions(); + if (strides[i] > 0) { + dims->set_edge_padding_low(begin[i]); + dims->set_interior_padding(strides[i] - 1); + + // Pad the upper dimension up to the expected input shape. (It's + // not sufficient simply to use "end[i]" to compute the padding in + // cases where the stride does not divide evenly into the interval + // between begin[i] and end[i].) + int64 size = + dims->edge_padding_low() + processing_shape.dim_size(i) + + (processing_shape.dim_size(i) - 1) * dims->interior_padding(); + dims->set_edge_padding_high(input_shape.dim_size(i) - size); + } else { + dimensions_to_reverse.push_back(i); + dims->set_edge_padding_high(input_shape.dim_size(i) - begin[i] - 1); + dims->set_interior_padding(-strides[i] - 1); + + // Pad the lower dimension up to the expected input shape. + int64 size = + dims->edge_padding_high() + processing_shape.dim_size(i) + + (processing_shape.dim_size(i) - 1) * dims->interior_padding(); + dims->set_edge_padding_low(input_shape.dim_size(i) - size); + } + } + if (!dimensions_to_reverse.empty()) { + grad = ctx->builder()->Rev(grad, dimensions_to_reverse); + } + grad = ctx->builder()->Pad(grad, zero, padding_config); + ctx->SetOutput(0, grad); + } + + private: + int32 begin_mask_, end_mask_; + int32 ellipsis_mask_, new_axis_mask_, shrink_axis_mask_; + DataType index_type_; +}; + +REGISTER_XLA_OP("StridedSliceGrad", StridedSliceGradOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc new file mode 100644 index 0000000000..45ac5e12c7 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc @@ -0,0 +1,128 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA-specific Tile Op. 
+ +#include +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/type_index.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { +namespace { + +// -------------------------------------------------------------------------- +class TileOp : public XlaOpKernel { + public: + explicit TileOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape input_shape = ctx->InputShape(0); + const TensorShape multiples_shape = ctx->InputShape(1); + + OP_REQUIRES( + ctx, IsLegacyVector(multiples_shape), + errors::InvalidArgument("Expected multiples to be 1-D, but got shape ", + multiples_shape.DebugString())); + OP_REQUIRES(ctx, input_shape.dims() == multiples_shape.num_elements(), + errors::InvalidArgument( + "Expected multiples argument to be a vector of length ", + input_shape.dims(), " but got length ", + multiples_shape.dim_size(0))); + const int input_dims = input_shape.dims(); + + // If input is a scalar then multiples has 0 elements and this is + // a NoOp. + if (input_dims == 0) { + ctx->SetOutput(0, ctx->Input(0)); + return; + } + + xla::Literal literal; + OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &literal)); + + // zero_element_result is true if the final shape has 0 elements, + // i.e. if any of the input dimensions or multiples is zero. + std::vector multiples_array(input_dims); + std::vector output_shape; + bool all_multiples_are_one = true; + bool one_dimension_is_broadcasted_without_multiple = true; + for (int i = 0; i < input_dims; ++i) { + int multiple = xla::LiteralUtil::Get(literal, {i}); + OP_REQUIRES(ctx, multiple, + errors::InvalidArgument("Expected multiples[", i, + "] >= 0, but got ", multiple)); + int64 new_dim = input_shape.dim_size(i) * multiple; + output_shape.push_back(new_dim); + multiples_array[i] = multiple; + all_multiples_are_one = all_multiples_are_one && multiple == 1; + // If the multiple of a non-one dimensions is not one, then binary + // operation broadcast semantics will not be sufficient to implement the + // tile operation. + one_dimension_is_broadcasted_without_multiple = + one_dimension_is_broadcasted_without_multiple && + ((input_shape.dim_size(i) > 1 && multiple == 1) || + input_shape.dim_size(i) == 1); + } + auto input = ctx->Input(0); + // If all multiples are 1, than the input is the same as the output. + if (all_multiples_are_one) { + ctx->SetOutput(0, input); + return; + } + if (one_dimension_is_broadcasted_without_multiple) { + // Create a constant Zero the size of the output shape to leverage binary + // operation broadcast semantics. + auto broadcasted_zero = ctx->builder()->Broadcast( + XlaHelpers::Zero(ctx->builder(), ctx->input_type(0)), output_shape); + ctx->SetOutput(0, ctx->builder()->Add(broadcasted_zero, input)); + return; + } + + // First broadcast the requisite number of multiples along each + // dimension. This prepends the broadcasted dimensions, so an + // input of shape [2,3,1] broadcast with multiples [5,4,3] will + // end up with shape [5,4,3,2,3,1]. 
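The reshape that follows interleaves each prepended multiple with its original dimension: for rank 3 the dimension order is {0, 3, 1, 4, 2, 5}, and collapsing each (multiple, dim) pair yields the tiled shape {10, 12, 3}. A small sketch that computes that order and the final shape for arbitrary rank (names illustrative):

#include <cassert>
#include <cstdint>
#include <vector>

// Given per-dimension multiples and the input shape, returns the dimension
// order used to interleave the broadcasted copies with the original dims,
// plus the final tiled shape, as in TileOp's Broadcast + Reshape lowering.
void TileReshapeSpec(const std::vector<int64_t>& input_shape,
                     const std::vector<int64_t>& multiples,
                     std::vector<int64_t>* order,
                     std::vector<int64_t>* tiled_shape) {
  const int64_t rank = static_cast<int64_t>(input_shape.size());
  order->clear();
  tiled_shape->clear();
  for (int64_t i = 0; i < rank; ++i) {
    order->push_back(i);         // the i-th broadcasted (multiple) dimension
    order->push_back(i + rank);  // paired with the i-th original dimension
    tiled_shape->push_back(input_shape[i] * multiples[i]);
  }
}

int main() {
  std::vector<int64_t> order, tiled;
  TileReshapeSpec({2, 3, 1}, {5, 4, 3}, &order, &tiled);
  assert((order == std::vector<int64_t>{0, 3, 1, 4, 2, 5}));
  assert((tiled == std::vector<int64_t>{10, 12, 3}));
  return 0;
}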
+ auto broadcasted = ctx->builder()->Broadcast(input, multiples_array); + // Now flatten and reshape. The broadcasted dimensions are + // paired with the original dimensions so in the above example + // we flatten [0,3,1,4,2,5] then reshape to [10,12,3]. + std::vector flattened; + for (int i = 0; i < output_shape.size(); ++i) { + flattened.push_back(i); + flattened.push_back(i + output_shape.size()); + } + xla::ComputationDataHandle output = + ctx->builder()->Reshape(broadcasted, flattened, output_shape); + + ctx->SetOutput(0, output); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(TileOp); +}; + +REGISTER_XLA_OP("Tile", TileOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc new file mode 100644 index 0000000000..2840abc878 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc @@ -0,0 +1,134 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA-specific Transpose Op. This is very different to the Eigen +// version in third_party/tensorflow because XLA's reshape neatly +// handles all transposes, while Eigen needs a restricted DoTranspose +// helper. + +#include "tensorflow/core/kernels/transpose_op.h" +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/kernels/bounds_check.h" + +namespace tensorflow { +namespace { + +class TransposeOp : public XlaOpKernel { + public: + explicit TransposeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape input_shape = ctx->InputShape(0); + const TensorShape perm_tensor_shape = ctx->InputShape(1); + + // Preliminary validation of sizes. + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(perm_tensor_shape), + errors::InvalidArgument("perm must be a vector, not ", + perm_tensor_shape.DebugString())); + + const int dims = input_shape.dims(); + OP_REQUIRES(ctx, dims == perm_tensor_shape.num_elements(), + errors::InvalidArgument("transpose expects a vector of size ", + input_shape.dims(), + ". But input(1) is a vector of size ", + perm_tensor_shape.num_elements())); + + xla::Literal literal; + OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {dims}, &literal)); + + std::vector perm(dims); + std::copy(literal.s32s().begin(), literal.s32s().end(), perm.begin()); + + std::vector transposed_order; + // Check whether permutation is a permutation of integers of [0 .. dims). 
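Both this check and InvertPermutation later in this file rest on the same property: every value of perm lies in [0, n) and occurs exactly once, so the inverse is obtained by writing inv[perm[i]] = i. A compact standalone version of the check plus inversion (the helper name is illustrative):

#include <cassert>
#include <stdexcept>
#include <vector>

// Validates that perm is a permutation of 0..n-1 and returns its inverse,
// i.e. inv[perm[i]] == i, the relationship InvertPermutationOp enforces.
std::vector<int> InvertPermutation(const std::vector<int>& perm) {
  const int n = static_cast<int>(perm.size());
  std::vector<int> inv(n, -1);
  for (int i = 0; i < n; ++i) {
    const int d = perm[i];
    if (d < 0 || d >= n) throw std::invalid_argument("index out of range");
    if (inv[d] != -1) throw std::invalid_argument("duplicate index");
    inv[d] = i;
  }
  return inv;
}

int main() {
  const std::vector<int> perm = {2, 0, 1};
  const std::vector<int> expected = {1, 2, 0};
  assert(InvertPermutation(perm) == expected);
  return 0;
}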
+ gtl::InlinedVector bits(dims); + bool is_identity = true; + for (int i = 0; i < dims; ++i) { + const int32 d = perm[i]; + OP_REQUIRES( + ctx, 0 <= d && d < dims, + errors::InvalidArgument(d, " is out of range [0 .. ", dims, ")")); + bits[d] = true; + transposed_order.push_back(d); + if (d != i) { + is_identity = false; + } + } + for (int i = 0; i < dims; ++i) { + OP_REQUIRES(ctx, bits[i], errors::InvalidArgument( + i, " is missing from 'perm' argument.")); + } + + // 0-D, 1-D, and identity transposes do nothing. + if (dims <= 1 || is_identity) { + ctx->SetOutput(0, ctx->Input(0)); + return; + } + + ctx->SetOutput(0, + ctx->builder()->Transpose(ctx->Input(0), transposed_order)); + } +}; + +REGISTER_XLA_OP("Transpose", TransposeOp); + +// InvertPermutation frequently forms part of the gradient of Transpose. +// +// inv = InvertPermutationOp(T p) takes a permutation of +// integers 0, 1, ..., n - 1 and returns the inverted +// permutation of p. I.e., inv[p[i]] == i, for i in [0 .. n). +// +// REQUIRES: input is a vector of int32. +// REQUIRES: input is a permutation of 0, 1, ..., n-1. + +class InvertPermutationOp : public XlaOpKernel { + public: + explicit InvertPermutationOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + OP_REQUIRES(ctx, FastBoundsCheck(ctx->InputShape(0).num_elements(), + std::numeric_limits::max()), + errors::InvalidArgument("permutation of nonnegative int32s " + "must have <= int32 max elements")); + + std::vector perm; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(0, &perm)); + + int size = perm.size(); + + std::vector output(size); + std::fill_n(output.data(), size, -1); + for (int i = 0; i < size; ++i) { + const int64 d = perm[i]; + OP_REQUIRES(ctx, FastBoundsCheck(d, size), + errors::InvalidArgument(d, " is not between 0 and ", size)); + OP_REQUIRES(ctx, output[d] == -1, + errors::InvalidArgument(d, " is duplicated in the input.")); + output[d] = i; + } + + ctx->SetOutput(0, ctx->builder()->ConstantR1(output)); + } +}; + +REGISTER_XLA_OP("InvertPermutation", InvertPermutationOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc new file mode 100644 index 0000000000..eced089b32 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -0,0 +1,70 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// Native XLA implementations of simple unary Ops + +#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/core/framework/kernel_def_builder.h" + +namespace tensorflow { +namespace { + +// A subclass of a TlaUnaryOp must build the lambda computation that +// describes the scalar->scalar function to apply to each element of +// the input. +#define XLAJIT_MAKE_UNARY(Name, COMPUTATION) \ + class Name##Op : public XlaOpKernel { \ + public: \ + explicit Name##Op(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} \ + void Compile(XlaOpKernelContext* ctx) { \ + xla::ComputationBuilder& b = *ctx->builder(); \ + xla::ComputationDataHandle x = ctx->Input(0); \ + xla::ComputationDataHandle y = COMPUTATION; \ + ctx->SetOutput(0, y); \ + } \ + }; \ + REGISTER_XLA_OP(#Name, Name##Op); + +// Return x if x>0, otherwise -x. +XLAJIT_MAKE_UNARY(Abs, b.Abs(x)); +XLAJIT_MAKE_UNARY(Ceil, b.Ceil(x)); +XLAJIT_MAKE_UNARY(Exp, b.Exp(x)); +XLAJIT_MAKE_UNARY(Floor, b.Floor(x)); +// Returns 0 if x is 0, -1 if x < 0 and 1 if x > 0. +XLAJIT_MAKE_UNARY(Sign, b.Sign(x)); +// Return 1/x +XLAJIT_MAKE_UNARY(Inv, b.Div(XlaHelpers::One(&b, input_type(0)), x)); +XLAJIT_MAKE_UNARY(Reciprocal, b.Div(XlaHelpers::One(&b, input_type(0)), x)); +XLAJIT_MAKE_UNARY(Log, b.Log(x)); +XLAJIT_MAKE_UNARY(LogicalNot, b.LogicalNot(x)); +XLAJIT_MAKE_UNARY(Neg, b.Neg(x)); +XLAJIT_MAKE_UNARY(Rsqrt, + b.Pow(x, XlaHelpers::FloatLiteral(&b, input_type(0), -0.5))); +XLAJIT_MAKE_UNARY(Sigmoid, b.Map({x}, *ctx->GetOrCreateSigmoid(input_type(0)))); +XLAJIT_MAKE_UNARY(Softplus, + b.Log(b.Add(b.Exp(x), XlaHelpers::One(&b, input_type(0))))); +XLAJIT_MAKE_UNARY(Sqrt, + b.Pow(x, XlaHelpers::FloatLiteral(&b, input_type(0), 0.5))); +XLAJIT_MAKE_UNARY(Square, b.Mul(x, x)); +XLAJIT_MAKE_UNARY(Tanh, b.Tanh(x)); + +#undef XLAJIT_MAKE_UNARY + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc new file mode 100644 index 0000000000..c5b2bdaf2d --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc @@ -0,0 +1,90 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA Unpack operator. 
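UnpackOp, defined below, emits one output per index along the chosen axis: output i is the slice [i, i + 1) of that axis, reshaped to drop the now-unit dimension. The index bookkeeping for a row-major 2-D array, sketched with plain loops (names illustrative):

#include <cassert>
#include <cstddef>
#include <vector>

// Unpacks a row-major [rows, cols] matrix along axis 0: output i is the
// slice of rows [i, i + 1) with the unit dimension dropped, i.e. row i.
std::vector<std::vector<double>> UnpackAxis0(const std::vector<double>& flat,
                                             size_t rows, size_t cols) {
  std::vector<std::vector<double>> outputs(rows);
  for (size_t i = 0; i < rows; ++i) {
    outputs[i].assign(flat.begin() + i * cols, flat.begin() + (i + 1) * cols);
  }
  return outputs;
}

int main() {
  // A 2x3 matrix {{1, 2, 3}, {4, 5, 6}} unpacks into two length-3 vectors.
  const auto outs = UnpackAxis0({1, 2, 3, 4, 5, 6}, 2, 3);
  assert((outs[0] == std::vector<double>{1, 2, 3}));
  assert((outs[1] == std::vector<double>{4, 5, 6}));
  return 0;
}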
+ +#include +#include + +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/kernels/concat_lib.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace { + +class UnpackOp : public XlaOpKernel { + public: + explicit UnpackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &axis_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + const int num = num_outputs(); + const TensorShape input_shape = ctx->InputShape(0); + + int axis = axis_; + if (axis < 0) axis += input_shape.dims(); + + OP_REQUIRES(ctx, 0 <= axis && axis < input_shape.dims(), + errors::InvalidArgument("axis = ", axis_, " not in [", + -input_shape.dims(), ", ", + input_shape.dims(), ")")); + + OP_REQUIRES( + ctx, input_shape.dims() > 0 && input_shape.dim_size(axis) == num, + errors::InvalidArgument("Input shape axis ", axis, " must equal ", num, + ", got shape ", input_shape.DebugString())); + + auto output_shape = input_shape; + output_shape.RemoveDim(axis); + + auto input = ctx->Input(0); + + std::vector start_indices(input_shape.dims(), 0); + std::vector limit_indices(input_shape.dims()); + for (int i = 0; i < input_shape.dims(); ++i) { + limit_indices[i] = input_shape.dim_size(i); + } + + for (int i = 0; i < num; ++i) { + start_indices[axis] = i; + limit_indices[axis] = i + 1; + auto slice = ctx->builder()->Slice(input, start_indices, limit_indices); + // Reshape to drop the 'axis' dimension. + auto result = ctx->builder()->Reshape(slice, output_shape.dim_sizes()); + ctx->SetOutput(i, result); + } + } + + private: + int axis_; +}; + +REGISTER_XLA_OP("Unpack", UnpackOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc new file mode 100644 index 0000000000..1f2bc01cf4 --- /dev/null +++ b/tensorflow/compiler/tf2xla/literal_util.cc @@ -0,0 +1,65 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/literal_util.h" + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/common_runtime/dma_helper.h" + +namespace tensorflow { + +Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal) { + literal->Clear(); + TF_RETURN_IF_ERROR(TensorShapeToXLAShape( + host_tensor.dtype(), host_tensor.shape(), literal->mutable_shape())); + + xla::LiteralUtil::Reserve(host_tensor.NumElements(), literal); + + // memcpy over the payload ... + // TODO(phawkins): handle string types. + size_t total_bytes = host_tensor.TotalBytes(); + if (total_bytes > 0) { + void* dst_ptr = xla::LiteralUtil::MutableInternalData(literal); + const void* src_ptr = DMAHelper::base(&host_tensor); + memcpy(dst_ptr, src_ptr, total_bytes); + } + return Status::OK(); +} + +Status LiteralToHostTensor(const xla::Literal& literal, DataType target_type, + Tensor* host_tensor) { + xla::PrimitiveType primitive_type; + TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(target_type, &primitive_type)); + if (literal.shape().element_type() != primitive_type) { + return errors::InvalidArgument( + "Cannot convert literal of type ", + xla::PrimitiveType_Name(literal.shape().element_type()), + " to tensor of type ", DataTypeString(target_type)); + } + + TensorShape shape = XLAShapeToTensorShape(literal.shape()); + *host_tensor = Tensor(target_type, shape); + size_t total_bytes = host_tensor->TotalBytes(); + if (total_bytes > 0) { + const void* src_ptr = xla::LiteralUtil::InternalData(literal); + void* dst_ptr = DMAHelper::base(host_tensor); + memcpy(dst_ptr, src_ptr, total_bytes); + } + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/literal_util.h b/tensorflow/compiler/tf2xla/literal_util.h new file mode 100644 index 0000000000..3e509375ef --- /dev/null +++ b/tensorflow/compiler/tf2xla/literal_util.h @@ -0,0 +1,42 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Utilities for working with XLA Literals. + +#ifndef TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_ +#define TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_ + +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// Copies 'host_tensor' to an XLA Literal. Fails if the host_tensor has zero +// elements or is of an unsupported type. +Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal); + +// Copies 'literal' to 'host_tensor', which is allocated of type . +// Fails if the literal's primitive type != +// DataTypeToPrimitiveType(target_type). 
Note that is not +// derivable from the type of , because multiple tensorflow types map +// to the same XLA type (e.g. INT32 and QINT32 both map to INT32 in +// XLA). +Status LiteralToHostTensor(const xla::Literal& literal, DataType target_type, + Tensor* host_tensor); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/literal_util_test.cc b/tensorflow/compiler/tf2xla/literal_util_test.cc new file mode 100644 index 0000000000..56993bc585 --- /dev/null +++ b/tensorflow/compiler/tf2xla/literal_util_test.cc @@ -0,0 +1,71 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/literal_util.h" + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/numeric_types.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +TEST(LiteralUtil, LiteralToHostTensor) { + // int64 literal can only be converted to an int64 host tensor. + { + std::vector int64_values = {1, 2, 3}; + std::unique_ptr int64_values_literal = + xla::LiteralUtil::CreateR1(gtl::ArraySlice(int64_values)); + Tensor host_tensor; + EXPECT_EQ("Cannot convert literal of type S64 to tensor of type int32", + LiteralToHostTensor(*int64_values_literal, DT_INT32, &host_tensor) + .error_message()); + EXPECT_EQ( + "Cannot convert literal of type S64 to tensor of type qint32", + LiteralToHostTensor(*int64_values_literal, DT_QINT32, &host_tensor) + .error_message()); + EXPECT_TRUE( + LiteralToHostTensor(*int64_values_literal, DT_INT64, &host_tensor) + .ok()); + test::ExpectTensorEqual(host_tensor, + test::AsTensor(int64_values)); + } + + { + // Repeat tests with int32. + Tensor host_tensor; + std::vector int32_values = {10, 11}; + std::unique_ptr int32_values_literal = + xla::LiteralUtil::CreateR1(gtl::ArraySlice(int32_values)); + EXPECT_TRUE( + LiteralToHostTensor(*int32_values_literal, DT_INT32, &host_tensor) + .ok()); + test::ExpectTensorEqual(host_tensor, + test::AsTensor(int32_values)); + + EXPECT_TRUE( + LiteralToHostTensor(*int32_values_literal, DT_QINT32, &host_tensor) + .ok()); + std::vector qint32_values = {10, 11}; + test::ExpectTensorEqual(host_tensor, + test::AsTensor(qint32_values)); + + EXPECT_EQ("Cannot convert literal of type S32 to tensor of type int64", + LiteralToHostTensor(*int32_values_literal, DT_INT64, &host_tensor) + .error_message()); + } +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/op_registrations.cc b/tensorflow/compiler/tf2xla/op_registrations.cc new file mode 100644 index 0000000000..d8a4dad4b3 --- /dev/null +++ b/tensorflow/compiler/tf2xla/op_registrations.cc @@ -0,0 +1,502 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Kernel registrations for XLA JIT devices. + +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +// CPU JIT device registrations. + +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("_Arg").TypeConstraint("T", kCpuAllTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("_ArrayToList")); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("_ListToArray")); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("_Retval").TypeConstraint("T", kCpuAllTypes)); + +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Abs").TypeConstraint("T", kCpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Add").TypeConstraint("T", kCpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("AddN").TypeConstraint("T", kCpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("All")); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("Any")); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("AvgPool").TypeConstraint("T", kCpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("AvgPoolGrad").TypeConstraint("T", kCpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("BatchMatMul").TypeConstraint("T", kCpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("BiasAdd").TypeConstraint("T", kCpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("BiasAddV1").TypeConstraint("T", kCpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("BiasAddGrad").TypeConstraint("T", kCpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("BroadcastGradientArgs")); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Cast") + .TypeConstraint("SrcT", kCpuAllTypes) + .TypeConstraint("DstT", kCpuAllTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Ceil").TypeConstraint("T", kCpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Concat").TypeConstraint("T", kCpuAllTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("ConcatV2") + .TypeConstraint("T", kCpuAllTypes) + .TypeConstraint("Tidx", DT_INT32)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("ConcatOffset")); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Conv2D").TypeConstraint("T", kCpuFloatTypes)); +REGISTER_XLA_KERNEL( + DEVICE_CPU_XLA_JIT, + Name("Conv2DBackpropFilter").TypeConstraint("T", kCpuFloatTypes)); +REGISTER_XLA_KERNEL( + DEVICE_CPU_XLA_JIT, + Name("Conv2DBackpropInput").TypeConstraint("T", kCpuFloatTypes)); +REGISTER_XLA_KERNEL( + DEVICE_CPU_XLA_JIT, + Name("DepthwiseConv2dNative").TypeConstraint("T", kCpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Diag").TypeConstraint("T", kCpuAllTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("DiagPart").TypeConstraint("T", kCpuAllTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Div").TypeConstraint("T", kCpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + 
Name("DynamicStitch").TypeConstraint("T", kCpuAllTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Equal").TypeConstraint("T", kCpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Exp").TypeConstraint("T", kCpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("ExpandDims").TypeConstraint("T", kCpuAllTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Fill").TypeConstraint("T", kCpuAllTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Floor").TypeConstraint("T", kCpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("FloorDiv").TypeConstraint("T", kCpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("FloorMod").TypeConstraint("T", kCpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Greater").TypeConstraint("T", kCpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("GreaterEqual").TypeConstraint("T", kCpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Inv").TypeConstraint("T", kCpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Reciprocal").TypeConstraint("T", kCpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("InvertPermutation").TypeConstraint("T", DT_INT32)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("L2Loss").TypeConstraint("T", kCpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Less").TypeConstraint("T", kCpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("LessEqual").TypeConstraint("T", kCpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("LinSpace").TypeConstraint("T", kCpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Log").TypeConstraint("T", kCpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("LogicalAnd")); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("LogicalNot")); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("LogicalOr")); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("LogSoftmax").TypeConstraint("T", kCpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("LRN").TypeConstraint("T", kCpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("LRNGrad").TypeConstraint("T", kCpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Maximum").TypeConstraint("T", kCpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("MatMul").TypeConstraint("T", kCpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("MatrixDiag").TypeConstraint("T", kCpuAllTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("MatrixDiagPart").TypeConstraint("T", kCpuAllTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Max").TypeConstraint("T", kCpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("MaxPool").TypeConstraint("T", kCpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("MaxPoolGrad").TypeConstraint("T", kCpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Mean").TypeConstraint("T", kCpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Min").TypeConstraint("T", kCpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Minimum").TypeConstraint("T", kCpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Mod").TypeConstraint("T", kCpuIntTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Mul").TypeConstraint("T", kCpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Neg").TypeConstraint("T", kCpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("NotEqual").TypeConstraint("T", kCpuNumericTypes)); 
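+// For orientation, a rough sketch (object name and counter value are
+// illustrative) of what one of the registrations above expands to:
+// REGISTER_XLA_KERNEL (defined in xla_compilation_device.h below) routes
+// through REGISTER_XLA_KERNEL_UNIQ, which builds a KernelDef for the JIT
+// device and hands it to a file-scope XlaKernelRegistrar, roughly:
+//
+//   static ::tensorflow::XlaKernelRegistrar
+//       xla_kernel_registrar__body__42__object(
+//           /*jit_only=*/false,
+//           ::tensorflow::register_kernel::Name("NotEqual")
+//               .TypeConstraint("T", kCpuNumericTypes)
+//               .Device(DEVICE_CPU_XLA_JIT)
+//               .Build());
+//
+// The registrar stores the KernelDef in XlaOpRegistry; RegisterJitKernels()
+// later pairs it with the OpKernel factory registered under the same op name.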
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Pack").TypeConstraint("T", kCpuAllTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Pad").TypeConstraint("T", kCpuAllTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Pow").TypeConstraint("T", kCpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("PreventGradient").TypeConstraint("T", kCpuAllTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Prod").TypeConstraint("T", kCpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Range").TypeConstraint("Tidx", kCpuNumericTypes)); +// TODO(b/31361304): disabled because of XLA bugs. +// REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("RandomStandardNormal")); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("RandomUniform")); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("RandomUniformInt")); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("Rank")); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("RealDiv").TypeConstraint("T", kCpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Relu").TypeConstraint("T", kCpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Relu6").TypeConstraint("T", kCpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("ReluGrad").TypeConstraint("T", kCpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Relu6Grad").TypeConstraint("T", kCpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Reshape").TypeConstraint("T", kCpuAllTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Rsqrt").TypeConstraint("T", kCpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("RsqrtGrad").TypeConstraint("T", kCpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Select").TypeConstraint("T", kCpuAllTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("Shape")); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("ShapeN")); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Sigmoid").TypeConstraint("T", kCpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("SigmoidGrad").TypeConstraint("T", kCpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Sign").TypeConstraint("T", kCpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("Size")); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Slice").TypeConstraint("T", kCpuAllTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Softmax").TypeConstraint("T", kCpuFloatTypes)); +REGISTER_XLA_KERNEL( + DEVICE_CPU_XLA_JIT, + Name("SoftmaxCrossEntropyWithLogits").TypeConstraint("T", kCpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Softplus").TypeConstraint("T", kCpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("SoftplusGrad").TypeConstraint("T", kCpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("SparseMatMul") + .TypeConstraint("Ta", kCpuFloatTypes) + .TypeConstraint("Tb", kCpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Split").TypeConstraint("T", kCpuAllTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("SplitV").TypeConstraint("T", kCpuAllTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Square").TypeConstraint("T", kCpuNumericTypes)); +REGISTER_XLA_KERNEL( + DEVICE_CPU_XLA_JIT, + Name("SquaredDifference").TypeConstraint("T", kCpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Squeeze").TypeConstraint("T", kCpuAllTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Sqrt").TypeConstraint("T", kCpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("StopGradient").TypeConstraint("T", 
kCpuAllTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("StridedSlice").TypeConstraint("T", kCpuAllTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("StridedSliceGrad").TypeConstraint("T", kCpuAllTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Sub").TypeConstraint("T", kCpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Sum").TypeConstraint("T", kCpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("SymbolicGradient")); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Tanh").TypeConstraint("T", kCpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("TanhGrad").TypeConstraint("T", kCpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Tile").TypeConstraint("T", kCpuAllTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Transpose").TypeConstraint("T", kCpuAllTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("TruncateDiv").TypeConstraint("T", kCpuIntTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("TruncateMod").TypeConstraint("T", kCpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Unpack").TypeConstraint("T", kCpuAllTypes)); +REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, + Name("ZerosLike").TypeConstraint("T", kCpuNumericTypes)); + +REGISTER_XLA_JIT_ONLY_KERNEL(DEVICE_CPU_XLA_JIT, + Name("Const").TypeConstraint("dtype", + kCpuAllTypes)); +REGISTER_XLA_JIT_ONLY_KERNEL( + DEVICE_CPU_XLA_JIT, Name("Identity").TypeConstraint("T", kCpuAllTypes)); +REGISTER_XLA_JIT_ONLY_KERNEL(DEVICE_CPU_XLA_JIT, Name("NoOp")); + +// GPU JIT device registrations + +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("_Arg").TypeConstraint("T", kGpuAllTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("_ArrayToList")); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("_ListToArray")); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("_Retval").TypeConstraint("T", kGpuAllTypes)); + +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("Abs").TypeConstraint("T", kGpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("Add").TypeConstraint("T", kGpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("AddN").TypeConstraint("T", kGpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("All")); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("Any")); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("AvgPool").TypeConstraint("T", kGpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("AvgPoolGrad").TypeConstraint("T", kGpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("BatchMatMul").TypeConstraint("T", kGpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("BiasAdd").TypeConstraint("T", kGpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("BiasAddV1").TypeConstraint("T", kGpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("BiasAddGrad").TypeConstraint("T", kGpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("BroadcastGradientArgs")); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("Cast") + .TypeConstraint("SrcT", kGpuAllTypes) + .TypeConstraint("DstT", kGpuAllTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("Ceil").TypeConstraint("T", kGpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("Concat").TypeConstraint("T", kGpuAllTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("ConcatV2").TypeConstraint("T", kGpuAllTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("ConcatOffset")); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("Conv2D").TypeConstraint("T", kGpuFloatTypes)); 
+REGISTER_XLA_KERNEL( + DEVICE_GPU_XLA_JIT, + Name("Conv2DBackpropFilter").TypeConstraint("T", kGpuFloatTypes)); +REGISTER_XLA_KERNEL( + DEVICE_GPU_XLA_JIT, + Name("Conv2DBackpropInput").TypeConstraint("T", kGpuFloatTypes)); +REGISTER_XLA_KERNEL( + DEVICE_GPU_XLA_JIT, + Name("DepthwiseConv2dNative").TypeConstraint("T", kGpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("Diag").TypeConstraint("T", kGpuAllTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("DiagPart").TypeConstraint("T", kGpuAllTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("Div").TypeConstraint("T", kGpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("DynamicStitch").TypeConstraint("T", kGpuAllTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("Equal").TypeConstraint("T", kGpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("Exp").TypeConstraint("T", kGpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("ExpandDims").TypeConstraint("T", kGpuAllTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("Fill").TypeConstraint("T", kGpuAllTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("Floor").TypeConstraint("T", kGpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("FloorDiv").TypeConstraint("T", kGpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("FloorMod").TypeConstraint("T", kGpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("Greater").TypeConstraint("T", kGpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("GreaterEqual").TypeConstraint("T", kGpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("Inv").TypeConstraint("T", kGpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("Reciprocal").TypeConstraint("T", kGpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("InvertPermutation").TypeConstraint("T", DT_INT32)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("L2Loss").TypeConstraint("T", kGpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("Less").TypeConstraint("T", kGpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("LessEqual").TypeConstraint("T", kGpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("LinSpace").TypeConstraint("T", kGpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("Log").TypeConstraint("T", kGpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("LogicalAnd")); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("LogicalNot")); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("LogicalOr")); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("LogSoftmax").TypeConstraint("T", kGpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("LRN").TypeConstraint("T", kGpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("LRNGrad").TypeConstraint("T", kGpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("Maximum").TypeConstraint("T", kGpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("MatMul").TypeConstraint("T", kGpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("MatrixDiag").TypeConstraint("T", kGpuAllTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("MatrixDiagPart").TypeConstraint("T", kGpuAllTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("Max").TypeConstraint("T", kGpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("MaxPool").TypeConstraint("T", kGpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("MaxPoolGrad").TypeConstraint("T", kGpuFloatTypes)); 
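+// Note that the registrations in this file only produce KernelDefs; the
+// OpKernel implementations themselves are registered separately with
+// REGISTER_XLA_OP, for example (class name illustrative):
+//
+//   REGISTER_XLA_OP("MaxPool", MaxPoolOp);
+//
+// XlaOpRegistry::RegisterJitKernels() CHECK-fails if a KernelDef here names an
+// op with no matching REGISTER_XLA_OP factory.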
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("Mean").TypeConstraint("T", kGpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("Min").TypeConstraint("T", kGpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("Minimum").TypeConstraint("T", kGpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("Mod").TypeConstraint("T", kGpuIntTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("Mul").TypeConstraint("T", kGpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("Neg").TypeConstraint("T", kGpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("NotEqual").TypeConstraint("T", kGpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("Pack").TypeConstraint("T", kGpuAllTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("Pad").TypeConstraint("T", kGpuAllTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("Pow").TypeConstraint("T", kGpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("PreventGradient").TypeConstraint("T", kGpuAllTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("Prod").TypeConstraint("T", kGpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("Range").TypeConstraint("Tidx", kGpuNumericTypes)); +// TODO(b/31361304): disabled because of XLA bugs. +// REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("RandomStandardNormal")); +// REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("RandomUniform")); +// REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("RandomUniformInt")); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("Rank")); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("RealDiv").TypeConstraint("T", kGpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("Relu").TypeConstraint("T", kGpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("Relu6").TypeConstraint("T", kGpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("ReluGrad").TypeConstraint("T", kGpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("Relu6Grad").TypeConstraint("T", kGpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("Reshape").TypeConstraint("T", kGpuAllTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("Rsqrt").TypeConstraint("T", kGpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("RsqrtGrad").TypeConstraint("T", kGpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("Select").TypeConstraint("T", kGpuAllTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("Shape")); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("ShapeN")); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("Sigmoid").TypeConstraint("T", kGpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("SigmoidGrad").TypeConstraint("T", kGpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("Sign").TypeConstraint("T", kGpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("Size")); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("Slice").TypeConstraint("T", kGpuAllTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("Softmax").TypeConstraint("T", kGpuFloatTypes)); +REGISTER_XLA_KERNEL( + DEVICE_GPU_XLA_JIT, + Name("SoftmaxCrossEntropyWithLogits").TypeConstraint("T", kGpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("Softplus").TypeConstraint("T", kGpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("SoftplusGrad").TypeConstraint("T", kGpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("SparseMatMul") + .TypeConstraint("Ta", kGpuFloatTypes) + .TypeConstraint("Tb", 
kGpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("Split").TypeConstraint("T", kGpuAllTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("SplitV").TypeConstraint("T", kGpuAllTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("Square").TypeConstraint("T", kGpuNumericTypes)); +REGISTER_XLA_KERNEL( + DEVICE_GPU_XLA_JIT, + Name("SquaredDifference").TypeConstraint("T", kGpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("Squeeze").TypeConstraint("T", kGpuAllTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("Sqrt").TypeConstraint("T", kGpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("StopGradient").TypeConstraint("T", kGpuAllTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("StridedSlice").TypeConstraint("T", kGpuAllTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("StridedSliceGrad").TypeConstraint("T", kGpuAllTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("Sub").TypeConstraint("T", kGpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("Sum").TypeConstraint("T", kGpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("SymbolicGradient")); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("Tanh").TypeConstraint("T", kGpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("TanhGrad").TypeConstraint("T", kGpuFloatTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("Tile").TypeConstraint("T", kGpuAllTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("Transpose").TypeConstraint("T", kGpuAllTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("TruncateDiv").TypeConstraint("T", kGpuIntTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("TruncateMod").TypeConstraint("T", kGpuNumericTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("Unpack").TypeConstraint("T", kGpuAllTypes)); +REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, + Name("ZerosLike").TypeConstraint("T", kGpuNumericTypes)); + +REGISTER_XLA_JIT_ONLY_KERNEL(DEVICE_GPU_XLA_JIT, + Name("Const").TypeConstraint("dtype", + kGpuAllTypes)); +REGISTER_XLA_JIT_ONLY_KERNEL( + DEVICE_GPU_XLA_JIT, Name("Identity").TypeConstraint("T", kGpuAllTypes)); +REGISTER_XLA_JIT_ONLY_KERNEL(DEVICE_GPU_XLA_JIT, Name("NoOp")); + +} // anonymous namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/shape_util.cc b/tensorflow/compiler/tf2xla/shape_util.cc new file mode 100644 index 0000000000..f5ecb51a5b --- /dev/null +++ b/tensorflow/compiler/tf2xla/shape_util.cc @@ -0,0 +1,54 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/shape_util.h" + +#include + +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// Convert an XLA Shape into the equivalent TensorFlow shape. 
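+// For example (illustrative): an xla::Shape of F32[2,3] becomes
+// TensorShape({2, 3}); the element type and layout are dropped, since a
+// TensorShape carries only dimension sizes.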
+TensorShape XLAShapeToTensorShape(const xla::Shape& shape) { + TensorShape tensor_shape; + for (int i = 0; i < xla::ShapeUtil::Rank(shape); ++i) { + tensor_shape.AddDim(shape.dimensions(i)); + } + return tensor_shape; +} + +// Convert a TensorShape into the equivalent XLA Shape proto. +Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape, + xla::Shape* shape) { + int rank = tensor_shape.dims(); + std::vector dimensions(rank); + std::vector layout(rank); + for (int d = 0; d < rank; ++d) { + dimensions[d] = tensor_shape.dim_size(d); + } + // XLA uses minor-to-major; Tensorflow uses major-to-minor. + std::iota(layout.rbegin(), layout.rend(), 0); + + xla::PrimitiveType type; + TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(dtype, &type)); + + *shape = xla::ShapeUtil::MakeShapeWithLayout(type, dimensions, layout); + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/shape_util.h b/tensorflow/compiler/tf2xla/shape_util.h new file mode 100644 index 0000000000..516dd636a9 --- /dev/null +++ b/tensorflow/compiler/tf2xla/shape_util.h @@ -0,0 +1,38 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Utilities for working with XLA shapes. + +#ifndef TENSORFLOW_COMPILER_TF2XLA_SHAPE_UTIL_H_ +#define TENSORFLOW_COMPILER_TF2XLA_SHAPE_UTIL_H_ + +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" + +namespace tensorflow { + +// Convert an XLA Shape into the equivalent TensorFlow shape. +TensorShape XLAShapeToTensorShape(const xla::Shape& shape); + +// Convert a TensorShape into the equivalent XLA Shape proto. Unlike Tensorflow, +// XLA shapes include the type. Not all `dtype` values can be represented by +// XLA, so this conversion may fail. +Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape, + xla::Shape* shape); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_SHAPE_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/str_util.cc b/tensorflow/compiler/tf2xla/str_util.cc new file mode 100644 index 0000000000..ce25d63127 --- /dev/null +++ b/tensorflow/compiler/tf2xla/str_util.cc @@ -0,0 +1,44 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/str_util.h"
+
+#include <string>
+#include <utility>
+#include <vector>
+
+namespace tensorflow {
+namespace str_util {
+
+void ReplaceAll(string* text, StringPiece from, StringPiece to) {
+  size_t pos = 0;
+  while ((pos = text->find(from.data(), pos, from.size())) != string::npos) {
+    text->replace(pos, from.size(), to.data(), to.size());
+    pos += to.size();
+    if (from.empty()) {
+      pos++;  // Match at the beginning of the text and after every byte
+    }
+  }
+}
+
+void ReplaceAllPairs(string* text,
+                     const std::vector<std::pair<string, string>>& replace) {
+  for (const std::pair<string, string>& from_to : replace) {
+    ReplaceAll(text, from_to.first, from_to.second);
+  }
+}
+
+}  // namespace str_util
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/str_util.h b/tensorflow/compiler/tf2xla/str_util.h
new file mode 100644
index 0000000000..4920b1a4d4
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/str_util.h
@@ -0,0 +1,46 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// String utilities that are esoteric enough that they don't belong in
+// third_party/tensorflow/core/lib/strings/str_util.h, but are still generally
+// useful under xla.
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_STR_UTIL_H_
+#define TENSORFLOW_COMPILER_TF2XLA_STR_UTIL_H_
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/core/lib/core/stringpiece.h"
+
+namespace tensorflow {
+namespace str_util {
+
+// Replace all non-overlapping occurrences of from with to in-place in text. If
+// from is empty, it matches at the beginning of the text and after every byte.
+void ReplaceAll(string* text, StringPiece from, StringPiece to);
+
+// Replace all non-overlapping occurrences of the given (from,to) pairs in-place
+// in text. If from is empty, it matches at the beginning of the text and after
+// every byte. Each (from,to) replacement pair is processed in the order it is
+// given.
+void ReplaceAllPairs(string* text,
+                     const std::vector<std::pair<string, string>>& replace);
+
+}  // namespace str_util
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_TF2XLA_STR_UTIL_H_
diff --git a/tensorflow/compiler/tf2xla/str_util_test.cc b/tensorflow/compiler/tf2xla/str_util_test.cc
new file mode 100644
index 0000000000..f992007a34
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/str_util_test.cc
@@ -0,0 +1,90 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/str_util.h" + +#include +#include +#include + +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace str_util { + +class ReplaceAllTest : public ::testing::Test { + protected: + void ExpectReplaceAll(string text, StringPiece from, StringPiece to, + StringPiece want) { + ReplaceAll(&text, from, to); + EXPECT_EQ(text, want); + } +}; + +TEST_F(ReplaceAllTest, Simple) { + ExpectReplaceAll("", "", "", ""); + ExpectReplaceAll("", "", "X", "X"); + ExpectReplaceAll("", "", "XYZ", "XYZ"); + ExpectReplaceAll("banana", "", "", "banana"); + ExpectReplaceAll("banana", "", "_", "_b_a_n_a_n_a_"); + ExpectReplaceAll("banana", "", "__", "__b__a__n__a__n__a__"); + ExpectReplaceAll("banana", "a", "a", "banana"); + ExpectReplaceAll("banana", "a", "", "bnn"); + ExpectReplaceAll("banana", "a", "X", "bXnXnX"); + ExpectReplaceAll("banana", "a", "XX", "bXXnXXnXX"); + ExpectReplaceAll("banana", "an", "an", "banana"); + ExpectReplaceAll("banana", "an", "", "ba"); + ExpectReplaceAll("banana", "an", "X", "bXXa"); + ExpectReplaceAll("banana", "an", "XY", "bXYXYa"); + ExpectReplaceAll("banana", "an", "XYZ", "bXYZXYZa"); + ExpectReplaceAll("foo {{bar}} baz {{bar}}", "{{bar}}", "X", "foo X baz X"); + ExpectReplaceAll("foo {{bar}} baz {{bar}}", "{{bar}}", "ABCDEFGHIJKLMNOP", + "foo ABCDEFGHIJKLMNOP baz ABCDEFGHIJKLMNOP"); +} + +class ReplaceAllPairsTest : public ::testing::Test { + protected: + void ExpectReplaceAllPairs( + string text, const std::vector>& replace, + StringPiece want) { + ReplaceAllPairs(&text, replace); + EXPECT_EQ(text, want); + } +}; + +TEST_F(ReplaceAllPairsTest, Simple) { + ExpectReplaceAllPairs("", {}, ""); + ExpectReplaceAllPairs("", {{"", ""}}, ""); + ExpectReplaceAllPairs("", {{"", "X"}}, "X"); + ExpectReplaceAllPairs("", {{"", "XYZ"}}, "XYZ"); + ExpectReplaceAllPairs("", {{"", "XYZ"}, {"", "_"}}, "_X_Y_Z_"); + ExpectReplaceAllPairs("", {{"", "XYZ"}, {"", "_"}, {"_Y_", "a"}}, "_XaZ_"); + ExpectReplaceAllPairs("banana", {}, "banana"); + ExpectReplaceAllPairs("banana", {{"", ""}}, "banana"); + ExpectReplaceAllPairs("banana", {{"", "_"}}, "_b_a_n_a_n_a_"); + ExpectReplaceAllPairs("banana", {{"", "__"}}, "__b__a__n__a__n__a__"); + ExpectReplaceAllPairs("banana", {{"a", "a"}}, "banana"); + ExpectReplaceAllPairs("banana", {{"a", ""}}, "bnn"); + ExpectReplaceAllPairs("banana", {{"a", "X"}}, "bXnXnX"); + ExpectReplaceAllPairs("banana", {{"a", "XX"}}, "bXXnXXnXX"); + ExpectReplaceAllPairs("banana", {{"a", "XX"}, {"XnX", "z"}}, "bXzzX"); + ExpectReplaceAllPairs("a{{foo}}b{{bar}}c{{foo}}", + {{"{{foo}}", "0"}, {"{{bar}}", "123456789"}}, + "a0b123456789c0"); +} + +} // namespace str_util +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/type_util.cc b/tensorflow/compiler/tf2xla/type_util.cc new file mode 100644 index 0000000000..b54848f342 --- /dev/null +++ b/tensorflow/compiler/tf2xla/type_util.cc @@ -0,0 +1,68 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/type_util.h" + +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { + +Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type) { + switch (data_type) { + case tensorflow::DT_BOOL: + *type = xla::PRED; + return Status::OK(); + case tensorflow::DT_INT8: + *type = xla::S8; + return Status::OK(); + case tensorflow::DT_INT16: + *type = xla::S16; + return Status::OK(); + case tensorflow::DT_INT32: + *type = xla::S32; + return Status::OK(); + case tensorflow::DT_INT64: + *type = xla::S64; + return Status::OK(); + case tensorflow::DT_UINT8: + *type = xla::U8; + return Status::OK(); + case tensorflow::DT_UINT16: + *type = xla::U16; + return Status::OK(); + case tensorflow::DT_HALF: + *type = xla::F16; + return Status::OK(); + case tensorflow::DT_FLOAT: + *type = xla::F32; + return Status::OK(); + case tensorflow::DT_DOUBLE: + *type = xla::F64; + return Status::OK(); + case tensorflow::DT_QUINT8: + *type = xla::U8; + return Status::OK(); + case tensorflow::DT_QINT32: + *type = xla::S32; + return Status::OK(); + default: + return errors::InvalidArgument( + "Unsupported type in DataTypeToPrimitiveType ", + DataTypeString(data_type)); + } +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/type_util.h b/tensorflow/compiler/tf2xla/type_util.h new file mode 100644 index 0000000000..bda667eb1f --- /dev/null +++ b/tensorflow/compiler/tf2xla/type_util.h @@ -0,0 +1,30 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_TYPE_UTIL_H_ +#define TENSORFLOW_COMPILER_TF2XLA_TYPE_UTIL_H_ + +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// Converts a Tensorflow DataType to an XLA PrimitiveType. +Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_TYPE_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc new file mode 100644 index 0000000000..86a53c929e --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc @@ -0,0 +1,203 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" + +#include +#include + +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/local_device.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace tensorflow { + +const char* const DEVICE_CPU_XLA_JIT = "XLA_CPU_JIT"; +const char* const DEVICE_GPU_XLA_JIT = "XLA_GPU_JIT"; + +// The XlaCompilationAllocator doesn't actually back any Tensors with storage +// buffers of values: instead for each Tensor it stores a +// XlaExpression which corresponds to the XLA computation +// represented by the Tensor. +class XlaCompilationAllocator : public Allocator { + public: + XlaCompilationAllocator() {} + ~XlaCompilationAllocator() override {} + + string Name() override { return "tla_jit"; } + + void* AllocateRaw(size_t alignment, size_t num_bytes) override { + // Regardless of the size requested, always allocate a + // XlaExpression. Respect the aligment request because there is + // alignment checking even for Tensors whose data is never + // accessed. + void* p = port::aligned_malloc(sizeof(XlaExpression), alignment); + XlaExpression* expression = reinterpret_cast(p); + new (expression) XlaExpression(); + return expression; + } + + void DeallocateRaw(void* ptr) override { + XlaExpression* expression = reinterpret_cast(ptr); + expression->~XlaExpression(); + port::aligned_free(ptr); + } + + // Make sure that even tensors with 0 elements have allocated + // buffers, so they get ids to track. + bool ShouldAllocateEmptyTensors() override { return true; } + + void GetStats(AllocatorStats* stats) override { stats->Clear(); } + + private: + // Don't run any constructors or destructors for complex objects, + // since there is no backing store for the tensor to run them + // on. strings are the only complex objects currently stored in + // Tensors. If others are added, this set of overrides must be + // extended to include them. 
+ void RunStringCtor(string* p, size_t n) override {} + void RunStringDtor(string* p, size_t n) override {} + void RunResourceCtor(ResourceHandle* p, size_t n) override {} + void RunResourceDtor(ResourceHandle* p, size_t n) override {} +}; + +XlaCompilationDevice::XlaCompilationDevice(const SessionOptions& options, + DeviceType type) + : LocalDevice(options, + Device::BuildDeviceAttributes( + "", type, Bytes(256 << 20), DeviceLocality(), + strings::StrCat("device: XLA JIT device ", type.type())), + cpu_allocator()), + allocator_(new XlaCompilationAllocator()) {} + +XlaCompilationDevice::~XlaCompilationDevice() {} + +Allocator* XlaCompilationDevice::GetAllocator(AllocatorAttributes attr) { + return allocator_.get(); +} + +Status XlaCompilationDevice::Sync() { return Status::OK(); } + +Status XlaCompilationDevice::MakeTensorFromProto( + const TensorProto& tensor_proto, const AllocatorAttributes alloc_attrs, + Tensor* tensor) { + return errors::InvalidArgument( + "Tla JIT Device should not parse tensor from proto"); +} + +// Is platform 'id' supported by XLA? +static bool IsPlatformSupported(perftools::gputools::Platform::Id id) { + auto platform = perftools::gputools::MultiPlatformManager::PlatformWithId(id); + if (!platform.ok()) return false; + return xla::ClientLibrary::GetOrCreateLocalClient(platform.ValueOrDie()).ok(); +} + +XlaOpRegistry::XlaOpRegistry() = default; +XlaOpRegistry::~XlaOpRegistry() = default; + +/* static */ void XlaOpRegistry::RegisterJitDevice( + const string& device_name, const string& jit_device_name, + bool requires_jit) { + XlaOpRegistry& registry = Instance(); + mutex_lock lock(registry.mutex_); + auto result = registry.jit_devices_.emplace( + device_name, std::make_pair(jit_device_name, requires_jit)); + CHECK(result.second || result.first->second.first == jit_device_name); +} + +/* static */ bool XlaOpRegistry::GetJitDevice(const string& device_name, + const string** jit_device_name, + bool* requires_jit) { + XlaOpRegistry& registry = Instance(); + + // Lazily register the CPU and GPU JIT devices the first time GetJitDevice is + // called. 
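+  // (The function-local static below runs its initializer exactly once, even
+  // with concurrent callers, under C++11 static-initialization rules; the
+  // lambda locks registry.mutex_ itself because the lock further down is only
+  // taken after it has run.)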
+  static void* registration = [&registry]() {
+    mutex_lock lock(registry.mutex_);
+    if (IsPlatformSupported(perftools::gputools::host::kHostPlatformId)) {
+      registry.jit_devices_[DEVICE_CPU] = {DEVICE_CPU_XLA_JIT, false};
+    }
+    if (IsPlatformSupported(perftools::gputools::cuda::kCudaPlatformId)) {
+      registry.jit_devices_[DEVICE_GPU] = {DEVICE_GPU_XLA_JIT, false};
+    }
+    return nullptr;
+  }();
+  (void)registration;
+
+  mutex_lock lock(registry.mutex_);
+  auto it = registry.jit_devices_.find(device_name);
+  if (it == registry.jit_devices_.end()) return false;
+  if (jit_device_name) *jit_device_name = &it->second.first;
+  if (requires_jit) *requires_jit = it->second.second;
+  return true;
+}
+
+void XlaOpRegistry::RegisterJitKernels() {
+  XlaOpRegistry& registry = Instance();
+  mutex_lock lock(registry.mutex_);
+
+  if (registry.jit_kernels_registered_) return;
+  registry.jit_kernels_registered_ = true;
+
+  for (const auto& entry : registry.kernels_) {
+    for (const XlaKernel& k : entry.second) {
+      auto it = registry.ops_.find(k.kernel_def->op());
+      CHECK(it != registry.ops_.end()) << "Missing XLA op registration for op "
+                                       << k.kernel_def->op();
+      registry.kernel_registrars_.emplace_back(
+          new kernel_factory::OpKernelRegistrar(new KernelDef(*k.kernel_def),
+                                                "XlaJitOp", it->second));
+    }
+  }
+}
+
+std::vector<const KernelDef*> XlaOpRegistry::DeviceKernels(
+    const string& jit_device_type) {
+  std::vector<const KernelDef*> kernels;
+  XlaOpRegistry& registry = Instance();
+  mutex_lock lock(registry.mutex_);
+  for (const XlaKernel& k : registry.kernels_.at(jit_device_type)) {
+    if (!k.jit_only) {
+      kernels.push_back(k.kernel_def.get());
+    }
+  }
+  return kernels;
+}
+
+XlaOpRegistry& XlaOpRegistry::Instance() {
+  static XlaOpRegistry* r = new XlaOpRegistry;
+  return *r;
+}
+
+XlaOpRegistrar::XlaOpRegistrar(StringPiece name,
+                               XlaOpRegistry::Factory factory) {
+  XlaOpRegistry& registry = XlaOpRegistry::Instance();
+  mutex_lock lock(registry.mutex_);
+  CHECK(registry.ops_.emplace(name.ToString(), factory).second)
+      << "Duplicate XLA op registration " << name;
+}
+
+XlaKernelRegistrar::XlaKernelRegistrar(bool jit_only, const KernelDef* def) {
+  XlaOpRegistry& registry = XlaOpRegistry::Instance();
+  mutex_lock lock(registry.mutex_);
+  registry.kernels_[def->device_type()].push_back(XlaOpRegistry::XlaKernel{
+      jit_only, std::unique_ptr<const KernelDef>(def)});
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.h b/tensorflow/compiler/tf2xla/xla_compilation_device.h
new file mode 100644
index 0000000000..f4b95b874b
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/xla_compilation_device.h
@@ -0,0 +1,214 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILATION_DEVICE_H_
+#define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILATION_DEVICE_H_
+
+#include
+#include
+
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/local_device.h"
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/mem.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+
+// Names of the XLA JIT devices. These are not user-visible, and are used
+// internally by the JIT to perform symbolic execution of a Tensorflow graph.
+
+extern const char* const DEVICE_CPU_XLA_JIT;  // "XLA_CPU_JIT"
+extern const char* const DEVICE_GPU_XLA_JIT;  // "XLA_GPU_JIT"
+
+constexpr std::array<DataType, 5> kCpuAllTypes = {
+    {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_BOOL}};
+constexpr std::array<DataType, 2> kCpuIntTypes = {{DT_INT32, DT_INT64}};
+constexpr std::array<DataType, 2> kCpuFloatTypes = {{DT_FLOAT, DT_DOUBLE}};
+constexpr std::array<DataType, 4> kCpuNumericTypes = {
+    {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE}};
+
+constexpr std::array<DataType, 5> kGpuAllTypes = {
+    {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_BOOL}};
+constexpr std::array<DataType, 2> kGpuIntTypes = {{DT_INT32, DT_INT64}};
+constexpr std::array<DataType, 2> kGpuFloatTypes = {{DT_FLOAT, DT_DOUBLE}};
+constexpr std::array<DataType, 4> kGpuNumericTypes = {
+    {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE}};
+
+// Class is declared and defined in xla_compilation_device.cc; the forward
+// reference is included here only so the XlaCompilationDevice allocator_
+// member can be defined.
+class XlaCompilationAllocator;
+
+// Deliberately don't register the device factory because we *never*
+// want soft placement to put Ops on a JIT device. Tests can include
+// the tla_jit_test_deps target which registers the factory, and when
+// using JIT in practice, the device is created manually not using a
+// factory.
+
+// This is a 'dummy' TensorFlow device that is only used to execute a
+// subgraph of XLA compilation Ops to construct a compiled version
+// of the subgraph's computation. It has a 'dummy' allocator that
+// backs each Tensor with metadata indicating the computation the
+// Tensor represents.
+class XlaCompilationDevice : public LocalDevice {
+ public:
+  XlaCompilationDevice(const SessionOptions& options, DeviceType type);
+
+  ~XlaCompilationDevice() override;
+
+  Allocator* GetAllocator(AllocatorAttributes attr) override;
+
+  Status Sync() override;
+
+  Status MakeTensorFromProto(const TensorProto& tensor_proto,
+                             const AllocatorAttributes alloc_attrs,
+                             Tensor* tensor) override;
+
+ private:
+  std::unique_ptr<XlaCompilationAllocator> allocator_;
+};
+
+// Class that manages registrations of operators and devices for the XLA JIT.
+// Not thread-safe.
+class XlaOpRegistry {
+ public:
+  typedef OpKernel* (*Factory)(OpKernelConstruction*);
+
+  // Registers 'jit_device_name' as the JIT device corresponding to
+  // 'device_name'. If 'requires_jit' is true, then operators placed on this
+  // device must be JIT-compiled. Dies if a conflicting registration already
+  // exists.
+  static void RegisterJitDevice(const string& device_name,
+                                const string& jit_device_name,
+                                bool requires_jit);
+
+  // Returns the JIT device name associated with 'device_name', setting
+  // 'jit_device_name' and 'requires_jit', if they are not null. Returns false
+  // and leaves 'jit_device_name' and 'requires_jit' unchanged if no matching
+  // JIT device is registered.
+  static bool GetJitDevice(const string& device_name,
+                           const string** jit_device_name, bool* requires_jit);
+
+  // Registers all JIT kernels on JIT devices, if not already registered.
+  // Does nothing otherwise.
+  static void RegisterJitKernels();
+
+  // Returns KernelDefs for JIT ops registered on 'jit_device_type'.
+  // Does not include kernels registered using REGISTER_XLA_JIT_ONLY_KERNEL.
+  static std::vector<const KernelDef*> DeviceKernels(
+      const string& jit_device_type);
+
+ private:
+  friend class XlaKernelRegistrar;
+  friend class XlaOpRegistrar;
+
+  static XlaOpRegistry& Instance();
+
+  XlaOpRegistry();
+  ~XlaOpRegistry();
+
+  mutex mutex_;
+
+  // Map from Tensorflow device names to the corresponding JIT device names.
+  std::unordered_map<string, std::pair<string, bool>> jit_devices_
+      GUARDED_BY(mutex_);
+
+  // Map from operator name to OpKernel factory, populated by REGISTER_XLA_OP.
+  std::unordered_map<string, Factory> ops_ GUARDED_BY(mutex_);
+
+  // Have we already registered the JIT kernels on the JIT devices?
+  bool jit_kernels_registered_ = false;
+
+  struct XlaKernel {
+    // Should this kernel be registered only on JIT devices, without a dummy
+    // kernel registered on the corresponding XLA device?
+    bool jit_only;
+
+    // KernelDef as built by REGISTER_XLA_KERNEL.
+    std::unique_ptr<const KernelDef> kernel_def;
+  };
+
+  // Map from JIT device name to a vector of XLA kernel descriptors.
+  std::unordered_map<string, std::vector<XlaKernel>> kernels_
+      GUARDED_BY(mutex_);
+
+  // Holds ownership of OpKernelRegistrars that represent the Tensorflow kernel
+  // registrations created by RegisterJitKernels() and RegisterDeviceKernels().
+  std::vector<std::unique_ptr<kernel_factory::OpKernelRegistrar>>
+      kernel_registrars_ GUARDED_BY(mutex_);
+};
+
+// REGISTER_XLA_OP() registers an XLA OpKernel by name, for example:
+//   REGISTER_XLA_OP("Add", AddOp);
+// where 'AddOp' is the name of a JIT OpKernel class that implements "Add".
+//
+// We don't use a variadic macro here because we don't expect JIT operators to
+// be templated.
+
+#define REGISTER_XLA_OP(NAME, OP) \
+  REGISTER_XLA_OP_UNIQ_HELPER(__COUNTER__, NAME, OP)
+
+// REGISTER_XLA_KERNEL() associates an XLA OpKernel with a particular device and
+// set of type constraints, e.g.,
+//   REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+//                       Name("Relu").TypeConstraint("T", DT_FLOAT));
+//
+// REGISTER_XLA_JIT_ONLY_KERNEL is similar to REGISTER_XLA_KERNEL(), but causes
+// XlaOpRegistry::RegisterDeviceKernels() to ignore the kernel.
+
+#define REGISTER_XLA_KERNEL(DEVICE, BUILDER) \
+  REGISTER_XLA_KERNEL_UNIQ_HELPER(__COUNTER__, DEVICE, BUILDER, false)
+
+#define REGISTER_XLA_JIT_ONLY_KERNEL(DEVICE, BUILDER) \
+  REGISTER_XLA_KERNEL_UNIQ_HELPER(__COUNTER__, DEVICE, BUILDER, true)
+
+// Implementation details.
+
+class XlaOpRegistrar {
+ public:
+  XlaOpRegistrar(StringPiece name, XlaOpRegistry::Factory factory);
+};
+
+#define REGISTER_XLA_OP_UNIQ_HELPER(COUNTER, NAME, OP) \
+  REGISTER_XLA_OP_UNIQ(COUNTER, NAME, OP)
+
+#define REGISTER_XLA_OP_UNIQ(CTR, NAME, OP) \
+  static ::tensorflow::XlaOpRegistrar xla_op_registrar__body__##CTR##__object( \
+      NAME, [](::tensorflow::OpKernelConstruction* context) \
+                -> ::tensorflow::OpKernel* { return new OP(context); });
+
+// Implementation details.
+
+class XlaKernelRegistrar { + public: + XlaKernelRegistrar(bool jit_only, const KernelDef* def); +}; + +#define REGISTER_XLA_KERNEL_UNIQ_HELPER(COUNTER, DEVICE, BUILDER, JIT_ONLY) \ + REGISTER_XLA_KERNEL_UNIQ(COUNTER, DEVICE, BUILDER, JIT_ONLY) + +#define REGISTER_XLA_KERNEL_UNIQ(CTR, DEVICE, BUILDER, JIT_ONLY) \ + static ::tensorflow::XlaKernelRegistrar \ + xla_kernel_registrar__body__##CTR##__object( \ + JIT_ONLY, \ + ::tensorflow::register_kernel::BUILDER.Device(DEVICE).Build()); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILATION_DEVICE_H_ diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc new file mode 100644 index 0000000000..e46c2a3148 --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -0,0 +1,405 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_compiler.h" + +#include + +#include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/executor.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/graph_optimizer.h" +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/version.h" + +namespace tensorflow { + +namespace { + +bool HasRetval(const Graph& graph) { + for (const Node* n : graph.nodes()) { + if (n->type_string() == "_Retval") return true; + } + return false; +} + +Status CheckSignature(const DataTypeVector& tf_types, + const xla::Shape& xla_shape) { + if (xla::ShapeUtil::IsTuple(xla_shape)) { + if (xla::ShapeUtil::TupleElementCount(xla_shape) != tf_types.size()) { + return errors::Internal("XLA shape has ", + xla::ShapeUtil::TupleElementCount(xla_shape), + " elements while function has ", tf_types.size()); + } + for (int i = 0; i < tf_types.size(); ++i) { + xla::PrimitiveType type; + TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(tf_types[i], &type)); + if (type != + xla::ShapeUtil::GetTupleElementShape(xla_shape, i).element_type()) { + return errors::Internal( + "element ", i, " has XLA type ", + xla::ShapeUtil::GetTupleElementShape(xla_shape, i).element_type(), + " and TensorFlow type ", DataTypeString(tf_types[i])); + } + } + } else { + if (tf_types.size() != 1) { + return errors::Internal("Expected singleton type, got ", tf_types.size(), + " types"); + } + xla::PrimitiveType type; + TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(tf_types[0], &type)); + if 
(type != xla_shape.element_type()) { + return errors::Internal("singleton element has XLA type ", + xla_shape.element_type(), " and TensorFlow type ", + DataTypeString(tf_types[0])); + } + } + return Status::OK(); +} + +} // namespace + +XlaCompiler::XlaCompiler(const XlaCompiler::Options& options) + : client_(options.client), + allow_cpu_custom_calls_(options.allow_cpu_custom_calls), + local_executable_has_hybrid_result_( + options.local_executable_has_hybrid_result), + next_step_id_(1), + device_(new XlaCompilationDevice(SessionOptions(), options.device_type)), + device_mgr_({device_}) {} + +XlaCompiler::~XlaCompiler() = default; + +int64 XlaCompiler::NextStepId() { + mutex_lock l(mu_); + return next_step_id_++; +} + +Status XlaCompiler::CompileFunction( + FunctionLibraryRuntime* flr, const NameAttrList& function, + const std::vector& args, + XlaCompiler::CompilationResult* result) { + const string function_id = Canonicalize(function.name(), function.attr()); + VLOG(1) << "XlaCompiler::CompileFunction " << function_id; + + FunctionLibraryRuntime::Handle handle; + TF_RETURN_IF_ERROR( + flr->Instantiate(function.name(), function.attr(), &handle)); + + const FunctionBody* fbody = flr->GetFunctionBody(handle); + CHECK(fbody); + + return CompileFunctionBody(flr, *fbody, function_id, args, + /*use_tuple_arg=*/false, result); +} + +Status XlaCompiler::CompileSubComputation(FunctionLibraryRuntime* flr, + const NameAttrList& function, + const xla::Shape& input_shape, + const xla::Shape& output_shape, + xla::Computation* computation) { + const string function_id = Canonicalize(function.name(), function.attr()); + VLOG(1) << "XlaCompiler::CompileSubComputation " << function_id; + + FunctionLibraryRuntime::Handle handle; + TF_RETURN_IF_ERROR( + flr->Instantiate(function.name(), function.attr(), &handle)); + + const FunctionBody* fbody = flr->GetFunctionBody(handle); + CHECK(fbody); + + TF_RETURN_IF_ERROR(CheckSignature(fbody->arg_types, input_shape)); + TF_RETURN_IF_ERROR(CheckSignature(fbody->ret_types, output_shape)); + + const bool use_tuple_arg = xla::ShapeUtil::IsTuple(input_shape); + + std::vector args(fbody->arg_types.size()); + if (use_tuple_arg) { + for (int i = 0; i < args.size(); ++i) { + xla::Shape xla_shape = + xla::ShapeUtil::GetTupleElementShape(input_shape, i); + args[i].type = fbody->arg_types[i]; + args[i].shape = XLAShapeToTensorShape(xla_shape); + args[i].parameter = i; + } + } else { + args[0].type = fbody->arg_types[0]; + args[0].shape = XLAShapeToTensorShape(input_shape); + args[0].parameter = 0; + } + + CompilationResult result; + TF_RETURN_IF_ERROR(CompileFunctionBody(flr, *fbody, function_id, args, + use_tuple_arg, &result)); + + if (!xla::ShapeUtil::Compatible(result.xla_output_shape, output_shape)) { + return errors::Internal("output shape mismatch from compilation"); + } + *computation = std::move(result.computation); + + return Status::OK(); +} + +Status XlaCompiler::CompileFunctionBody( + FunctionLibraryRuntime* flr, const FunctionBody& fbody, + const string& function_id, const std::vector& args, + bool use_tuple_arg, XlaCompiler::CompilationResult* result) { + VLOG(1) << "XlaCompiler::CompileFunctionBody " << function_id; + + std::unique_ptr graph(new Graph(flr->GetFunctionLibraryDefinition())); + CopyGraph(*fbody.graph, graph.get()); + + if (VLOG_IS_ON(1)) { + dump_graph::DumpGraphToFile( + strings::StrCat("xla_jit_raw_input_", function_id), *graph); + } + + if (!HasRetval(*graph)) { + VLOG(1) << "Graph has no retvals. 
Skipping compilation."; + return Status::OK(); + } + + // Optimize the graph to before running throught the translator. + // TODO(pbar) The constant folder currently does not simplify int32 operations + // for devices other than CPU. + OptimizerOptions opts; + GraphOptimizer optimizer(opts); + Graph* g = graph.release(); + OptimizeGraph(flr, &g); + graph.reset(g); + + if (VLOG_IS_ON(1)) { + dump_graph::DumpGraphToFile( + strings::StrCat("xla_jit_final_graph_", function_id), *graph); + } + + VLOG(1) << "===================================================="; + TF_RETURN_IF_ERROR(CompileGraph(function_id, std::move(graph), flr, args, + use_tuple_arg, result)); + VLOG(1) << "===================================================="; + + return Status::OK(); +} + +Status XlaCompiler::BuildExecutable( + const XlaCompiler::CompilationResult& result, + std::unique_ptr* executable) { + VLOG(2) << "Compiling to local executable"; + xla::Shape opaque_shape = xla::ShapeUtil::MakeOpaqueShape(); + + std::vector argument_layouts( + result.xla_input_shapes.size()); + for (int i = 0; i < result.xla_input_shapes.size(); ++i) { + argument_layouts[i] = &result.xla_input_shapes[i].second; + } + if (result.requires_runtime_context) { + // The final arg is the XlaLocalRuntimeContext*. + argument_layouts.push_back(&opaque_shape); + } + xla::LocalClient* local_client = static_cast(client()); + xla::ExecutableBuildOptions build_options; + build_options.set_device_ordinal(local_client->default_device_ordinal()); + build_options.set_platform(local_client->platform()); + build_options.set_result_layout(result.xla_output_shape); + build_options.set_has_hybrid_result(local_executable_has_hybrid_result_); + + auto compile_result = local_client->Compile(result.computation, + argument_layouts, build_options); + if (!compile_result.ok()) { + return compile_result.status(); + } + *executable = std::move(compile_result.ValueOrDie()); + return Status::OK(); +} + +namespace { + +Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, + XlaCompilationDevice* device, FunctionLibraryRuntime* flib, + int64 step_id) { + // Resource cleanup is a bit messy. XlaContext is a ref-counted resource; the + // resource manager takes ownership via Create, and unrefs via Cleanup. We + // explicitly add a reference to ensure the refcount at entry is maintained at + // all exit points; Create and Cleanup are always called in this function. + // + // The Executor requires us to use ScopedStepContainer. We wrap it in a + // unique_ptr so we can capture the cleanup status in the end. + xla_context->Ref(); + Status cleanup_status; + auto step_container = xla::MakeUnique( + step_id, [&cleanup_status, device](const string& name) { + cleanup_status = device->resource_manager()->Cleanup(name); + }); + TF_RETURN_IF_ERROR(device->resource_manager()->Create( + step_container->name(), XlaContext::kXlaContextResourceName, + xla_context)); + + // Create a LocalExecutor that will own and run the graph. + LocalExecutorParams exec_params; + exec_params.device = device; + exec_params.function_library = flib; + exec_params.create_kernel = [flib](const NodeDef& ndef, OpKernel** kernel) { + return flib->CreateKernel(ndef, kernel); + }; + exec_params.delete_kernel = [](OpKernel* kernel) { delete kernel; }; + Executor* exec_ptr = nullptr; + TF_RETURN_IF_ERROR(NewLocalExecutor(exec_params, graph.release(), &exec_ptr)); + std::unique_ptr exec(exec_ptr); + // At this point ownership of the graph has been transferred to exec. 
+ + auto runner = [](Executor::Args::Closure c) { + // TODO(misard) Temporarily just schedule c eagerly while we + // decide what to do about the fact that the ComputationBuilder is + // thread-compatible, but we don't really want Op writers to have + // to remember to acquire a lock around every call to + // ComputationBuilder. One possibility is to add the (generally + // useful) ability to run a single-threaded Executor based on an + // option in LocalExecutorParams. Another is to automagically + // acquire a lock around ComputationBuilder calls using some + // wrapper or RAII funny business. + c(); + }; + + // Run the graph symbolically, turning the graph into an XLA computation. + Executor::Args exec_args; + exec_args.step_id = step_id; + exec_args.step_container = step_container.get(); + exec_args.runner = runner; + TF_RETURN_WITH_CONTEXT_IF_ERROR( + exec->Run(exec_args), + "Conversion from TensorFlow graph to XLA computation failed."); + + // Explicitly clean up the step container, to capture the cleanup status. + step_container.reset(); + return cleanup_status; +} + +} // namespace + +Status XlaCompiler::CompileGraph(string const& name, + std::unique_ptr graph, + FunctionLibraryRuntime* flib, + const std::vector& args, + bool use_tuple_arg, + CompilationResult* result) { + VLOG(1) << "Executing graph symbolically to populate ComputationBuilder."; + + // Converts the input shapes into xla::Shape instances. + result->xla_input_shapes.reserve(args.size()); + for (int i = 0; i < args.size(); ++i) { + if (args[i].parameter < 0) { + continue; + } + result->xla_input_shapes.push_back(std::make_pair(i, xla::Shape())); + TF_RETURN_IF_ERROR(TensorShapeToXLAShape( + args[i].type, args[i].shape, &result->xla_input_shapes.back().second)); + } + + XlaContext* xla_context = + new XlaContext(client(), name, allow_cpu_custom_calls_); + core::ScopedUnref xla_context_unref(xla_context); + + TF_RETURN_IF_ERROR(xla_context->BuildArguments(args, use_tuple_arg)); + + TF_RETURN_IF_ERROR( + ExecuteGraph(xla_context, std::move(graph), device_, flib, NextStepId())); + + std::vector compile_time_constants; + int num_nonconst_outputs; + TF_RETURN_IF_ERROR(xla_context->CollectResults( + &result->computation, &result->requires_runtime_context, + &compile_time_constants, &num_nonconst_outputs)); + + result->outputs.resize(compile_time_constants.size() + num_nonconst_outputs); + for (const auto& c : compile_time_constants) { + if (!c.status.ok()) { + Status constant_status = c.status; + errors::AppendToMessage(&constant_status, + "Failed evaluating constant XLA return " + "value ", + c.index); + return constant_status; + } + if (c.index >= result->outputs.size()) { + return errors::InvalidArgument("Invalid argument index ", c.index); + } + OutputDescription& output = result->outputs[c.index]; + output.shape = c.value.shape(); + output.is_constant = true; + output.constant_value = c.value; + } + + if (result->computation.IsNull()) { + return Status::OK(); + } + + // Compute the output shapes, if there is a computation with non-constant + // outputs. + auto computation_shape = client()->GetComputationShape(result->computation); + if (!computation_shape.ok()) { + return computation_shape.status(); + } + + result->xla_output_shape.Swap( + computation_shape.ValueOrDie()->mutable_result()); + + auto num_non_constant_outputs = + (xla::ShapeUtil::IsTuple(result->xla_output_shape)) + ? xla::ShapeUtil::TupleElementCount(result->xla_output_shape) + : 1; + // Tensorflow expects a major-to-minor order of results. 
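+  // For example, a rank-3 f32 result receives the row-major layout
+  // minor_to_major = {2, 1, 0}, i.e. the last TensorFlow dimension is the
+  // fastest-varying (minor-most) one.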
+ if (1 == num_non_constant_outputs) { + xla::Shape& s = result->xla_output_shape; + auto& minor_to_major = *s.mutable_layout()->mutable_minor_to_major(); + minor_to_major.Resize(xla::ShapeUtil::Rank(s), 0); + std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0); + } else { + for (xla::Shape& s : *result->xla_output_shape.mutable_tuple_shapes()) { + auto& minor_to_major = *s.mutable_layout()->mutable_minor_to_major(); + minor_to_major.Resize(xla::ShapeUtil::Rank(s), 0); + std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0); + } + } + + // Converts the output shapes to TensorShapes. + int computation_output = 0; + for (int i = 0; i < result->outputs.size(); ++i) { + if (!result->outputs[i].is_constant) { + CHECK_LT(computation_output, num_non_constant_outputs); + if (num_non_constant_outputs > 1) { + result->outputs[i].shape = + XLAShapeToTensorShape(xla::ShapeUtil::GetTupleElementShape( + result->xla_output_shape, computation_output)); + } else { + result->outputs[i].shape = + XLAShapeToTensorShape(result->xla_output_shape); + } + ++computation_output; + } + } + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h new file mode 100644 index 0000000000..0b882d60a1 --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -0,0 +1,203 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_ +#define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_ + +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/notification.h" +#include "tensorflow/core/platform/thread_annotations.h" + +namespace tensorflow { + +// The XlaCompiler class is responsible for compilation of a self-contained +// subgraph of a TensorFlow computation using the XLA linear algebra runtime. +// It does a symbolic execution of the graph starting from specific input +// shapes, using a JIT device to convert operators into XLA computations. +// +// It is typically invoked from an `_XlaLaunch` operator once the shapes +// of all input parameters to the computation are known. This is +// because the symbolic execution requires known shapes for all operations. +class XlaCompiler { + public: + // Describes how to derive the value of each _Arg node in the graph/function + // being compiled. Each argument must be either a parameter of the generated + // XLA computation (parameter >= 0), or a compile time constant + // (parameter < 0). + struct Argument { + // The type of the argument. + DataType type; + + // The shape of the argument. 
+ TensorShape shape; + + // The parameter number of this argument to the XLA computation. < 0 + // means this is a compile-time constant argument. + int parameter; + + // The value of the argument, if it is a compile-time constant. Must be a + // host-memory tensor. + Tensor constant_value; + + // The name of this argument, used for debugging. + string name; + }; + + struct OutputDescription { + // Shape of the output. + TensorShape shape; + + // Constant output value, if known to be constant at JIT compilation time. + // 'Tensor' is in host memory. + bool is_constant = false; + Tensor constant_value; + }; + + struct CompilationResult { + // Vector of (Tensorflow input number, XLA shape) pairs that describe + // the arguments of the compiled XLA computation. (Because of constant + // inputs, the arguments to the XLA computation are a subset of the + // inputs passed to the JIT.) + std::vector> xla_input_shapes; + + // Does the computation require the local runtime context to be passed as + // the last argument? + bool requires_runtime_context = false; + + // Output shape in XLA format. This is a tuple if and only if + // there are multiple non-constant outputs. + xla::Shape xla_output_shape; + + // TensorFlow shapes of outputs, together with the values of any + // constant arguments. Vector indexed by Tensorflow _Retval number, + // containing both constant and non-constant arguments. + std::vector outputs; + + // The XLA computation built from the tensorflow subgraph. May be null + // if the output consists solely of compile-time constants. + xla::Computation computation; + }; + + struct Options { + // Name of the compilation device to use. + DeviceType device_type = DeviceType(""); + + xla::Client* client = nullptr; + + // If 'allow_cpu_custom_calls' is true, kernels may make use of CustomCall() + // for CPU; additionally, an optional XlaLocalRuntimeContext* may be passed + // to the computation. + bool allow_cpu_custom_calls = false; + + // If 'local_executable_has_hybrid_result', the top-level pointers of the + // result tuple of compiled programs are stored in host memory and the + // nested buffers in device memory, otherwise the whole result tuple is + // stored in device memory. + bool local_executable_has_hybrid_result = false; + }; + + explicit XlaCompiler(const Options& options); + ~XlaCompiler(); + + // Compiles a Tensorflow function `fn_name_attrs` into an XLA computation. + // `args` describes the arguments to the function, each of which must either + // be a parameter to the XLA computation or a compile-time constant. + // Writes the compiled output to `result`. + // + // The generated XLA computation returns a tuple containing only the + // non-constant outputs as a function of the input arguments. Constant + // arguments are returned as host memory tensors in the output list and are + // not included in the XLA computation's outputs. The XLA computation is + // null if there are no data-dependent outputs. + Status CompileFunction(FunctionLibraryRuntime* flr, + const NameAttrList& fn_name_attrs, + const std::vector& args, + CompilationResult* result); + + // Compiles a tensorflow::Graph into an xla::Computation. + // Similar to CompileFunction, but takes a Graph as input rather than a + // function. + // If `use_tuple_arg` is true, the compilation takes all of its arguments as + // a single tuple. 
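+  //
+  // A rough usage sketch, for illustration only (`flr`, `fn`, `executable`
+  // and the device constant below are placeholders, not values defined by
+  // this header):
+  //
+  //   XlaCompiler::Options options;
+  //   options.device_type = DeviceType(DEVICE_CPU_XLA_JIT);  // assumed name
+  //   options.client = xla::ClientLibrary::LocalClientOrDie();
+  //   XlaCompiler compiler(options);
+  //
+  //   std::vector<XlaCompiler::Argument> args(1);
+  //   args[0].type = DT_FLOAT;
+  //   args[0].shape = TensorShape({2, 3});
+  //   args[0].parameter = 0;  // a runtime parameter, not a constant
+  //
+  //   XlaCompiler::CompilationResult result;
+  //   TF_RETURN_IF_ERROR(compiler.CompileFunction(flr, fn, args, &result));
+  //   TF_RETURN_IF_ERROR(compiler.BuildExecutable(result, &executable));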
+ Status CompileGraph(string const& name, std::unique_ptr graph, + FunctionLibraryRuntime* flr, + const std::vector& args, bool use_tuple_arg, + CompilationResult* result); + + // Helper function that compiles a function to an XLA computation suitable + // for use as a subroutine in other Computations, e.g., the body of a + // While loop. + // + // The emitted Computation takes a single input parameter with + // input_shape. If this is a tuple then the tuple element shapes + // must match the types of the function's _Arg nodes. If input_shape + // is not a tuple then the function must have a single _Arg node + // with the same type as input_shape. The shapes of the _Arg values + // will be compiled to match input_shape. + // + // The emitted Computation also returns a single value. If output_shape is a + // tuple the tuple elements' types and shapes must match the compiled + // function's _Retval nodes. If output_shape is not a tuple the + // function must have a single _Retval node with the correct type + // (and shape after compilation). + Status CompileSubComputation(FunctionLibraryRuntime* flr, + const NameAttrList& fn_name_attrs, + const xla::Shape& input_shape, + const xla::Shape& output_shape, + xla::Computation* computation); + + // Takes <*result>, which has been compiled from a Tensorflow subgraph to a + // XLA computation already, and generates an XLA LocalExecutable `executable`. + Status BuildExecutable(const CompilationResult& result, + std::unique_ptr* executable); + + xla::Client* client() const { return client_; } + XlaCompilationDevice* device() const { return device_; } + const DeviceMgr* device_mgr() const { return &device_mgr_; } + + private: + // Does the real work of Compile() and CompileToComputation(). + Status CompileFunctionBody(FunctionLibraryRuntime* function_library, + const FunctionBody& function_body, + const string& name, + const std::vector& args, + bool use_tuple_arg, CompilationResult* result); + + xla::Client* client_; // Not owned. + const bool allow_cpu_custom_calls_; + const bool local_executable_has_hybrid_result_; + + // Returns the next step sequence number. + int64 NextStepId(); + + mutex mu_; + + // Internal sequence number for steps executed on the compilation device. + int64 next_step_id_ GUARDED_BY(mu_); + + XlaCompilationDevice* device_; // Owned by device_mgr_ + DeviceMgr device_mgr_; + + TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_ diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc new file mode 100644 index 0000000000..ad8fc3f205 --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -0,0 +1,331 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_context.h" + +#include +#include +#include + +#include "tensorflow/compiler/tf2xla/literal_util.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/common_runtime/dma_helper.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" + +namespace tensorflow { + +XlaExpression::XlaExpression() : has_constant_value_(false) {} + +void XlaExpression::set_handle(const xla::ComputationDataHandle& h) { + handle_ = h; +} +const xla::ComputationDataHandle& XlaExpression::handle() const { + return handle_; +} + +void XlaExpression::set_constant_value(Tensor value) { + has_constant_value_ = true; + constant_value_ = std::move(value); +} + +const char XlaContext::kXlaContextResourceName[] = "_xla_context"; + +// Looks up the context associated with the current step. It is stored +// in a resource container managed by the device. +/* static */ XlaContext& XlaContext::Get(const OpKernelContext* ctx) { + // When an Op kernel wants to use an XLA JIT context, the + // per-step context is looked up in the resource manager. The + // JIT will prepopulate the JITContext. + XlaContext* context; + TF_CHECK_OK(ctx->resource_manager()->Lookup( + ctx->step_container()->name(), kXlaContextResourceName, &context)); + // The resource manager handed us a fresh reference to 'context', but retains + // a reference itself so the context won't be freed. The resource manager will + // outlive the JIT compilation. + context->Unref(); + return *context; +} + +Status XlaContext::BuildArguments(std::vector args, + bool use_tuple_arg) { + args_ = std::move(args); + use_tuple_arg_ = use_tuple_arg; + + // Compute the number of parameters, verify that they are sequential starting + // from 0 + num_parameters_ = 0; + for (const XlaCompiler::Argument& arg : args_) { + if (arg.parameter < 0) continue; + if (num_parameters_ != arg.parameter) { + return errors::InvalidArgument( + "Parameter numbers to JIT compilation are not consecutive starting " + "from 0"); + } + ++num_parameters_; + + if (arg.shape.num_elements() == 0) { + return errors::InvalidArgument( + "Non-constant argument must have a non-zero number of elements."); + } + } + if (num_parameters_ == 0) return Status::OK(); + + parameters_.resize(num_parameters_); + + std::vector parameter_shapes(num_parameters_); + for (int i = 0; i < args_.size(); ++i) { + const XlaCompiler::Argument& arg = args_[i]; + if (arg.parameter < 0) continue; + // Computes the shapes of non-constant arguments. 
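+    // For instance, a DT_FLOAT argument with TensorShape({2, 3}) maps to the
+    // XLA shape f32[2,3].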
+ xla::PrimitiveType type; + TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(arg.type, &type)); + xla::ShapeUtil::PopulateShape(type, arg.shape.dim_sizes(), + ¶meter_shapes[arg.parameter]); + } + + if (use_tuple_arg_ && num_parameters_ > 0) { + xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(parameter_shapes); + xla::ComputationDataHandle tuple = + builder().Parameter(0, tuple_shape, "arg_tuple"); + for (int i = 0; i < args_.size(); ++i) { + const XlaCompiler::Argument& arg = args_[i]; + if (arg.parameter < 0) continue; + parameters_[arg.parameter] = + builder().GetTupleElement(tuple, arg.parameter); + } + } else { + for (int i = 0; i < args_.size(); ++i) { + const XlaCompiler::Argument& arg = args_[i]; + if (arg.parameter < 0) continue; + parameters_[arg.parameter] = + builder().Parameter(arg.parameter, parameter_shapes[arg.parameter], + strings::StrCat("arg", i)); + } + } + return Status::OK(); +} + +Status XlaContext::CollectResults( + xla::Computation* computation, bool* requires_runtime_context, + std::vector* compile_time_constants, + int* num_nonconst_outputs) { + mutex_lock l(mu_); + + bool return_singleton = (1 == retval_.size()); + + xla::ComputationDataHandle handle; + if (return_singleton) { + handle = retval_[0].second; + + // TODO(b/31775371): to workaround bug, add a no-op computation that is + // guaranteed to be constructed after all of the formal parameters to the + // computation. + handle = builder().GetTupleElement(builder().Tuple({handle}), 0); + + // Ensure that the retval is returned even if another computation + // was mistakenly placed on the ComputationBuilder. + TF_CHECK_OK(builder().SetReturnValue(handle)); + } else { + if (!retval_.empty()) { + // There is at least one data-dependent expression: combine them + // into a Tuple in index order before compiling. + VLOG(1) << "Making the retval tuple."; + std::sort(retval_.begin(), retval_.end(), + [](const std::pair& a, + const std::pair& b) { + return a.first < b.first; + }); + std::vector elems; + elems.reserve(retval_.size()); + for (const std::pair& r : retval_) { + elems.push_back(r.second); + } + // Make a tuple from the vector of handles. + handle = builder().Tuple(elems); + } + } + + if (handle.handle() > 0) { + // Build the full computation. The return value is the handle + // constructed above. + xla::StatusOr computation_status = builder().Build(); + if (!computation_status.ok()) { + return computation_status.status(); + } + *computation = computation_status.ConsumeValueOrDie(); + } + + // Make sure the compile time constants are in RetVal index order. + std::sort(compile_time_constant_.begin(), compile_time_constant_.end(), + [](const ConstRetVal& a, const ConstRetVal& b) { + return a.index < b.index; + }); + + // Fill in the result details and return. 
+ *compile_time_constants = std::move(compile_time_constant_); + *requires_runtime_context = has_context_parameter_; + *num_nonconst_outputs = retval_.size(); + return Status::OK(); +} + +XlaContext::XlaContext(xla::Client* client, const string& computation_name, + bool allow_cpu_custom_calls) + : xla_builder_(client, computation_name), + allow_cpu_custom_calls_(allow_cpu_custom_calls) {} + +const xla::ComputationDataHandle& +XlaContext::GetOrCreateRuntimeContextParameter() { + mutex_lock lock(mu_); + CHECK(allow_cpu_custom_calls_); + CHECK(!use_tuple_arg_); + if (has_context_parameter_) return context_parameter_; + has_context_parameter_ = true; + context_parameter_ = xla_builder_.Parameter( + num_parameters_, xla::ShapeUtil::MakeOpaqueShape(), "tf_context"); + return context_parameter_; +} + +string XlaContext::DebugString() { return "TLA JIT context"; } + +// This is called by the Retval Op to associate a computed value +// with a specific return value of the subgraph. +void XlaContext::AddRetval(int retval_index, + const xla::ComputationDataHandle& handle) { + VLOG(1) << "Added retval index " << retval_index << " to XLA computation"; + // Add the return value to the list being built up. The executor + // is multi-threaded so this has to happen under the + // lock. + mutex_lock l(mu_); + retval_.emplace_back(retval_index, handle); +} + +Status XlaContext::AddConstRetval(int retval_index, DataType dtype, + const xla::Literal& literal) { + VLOG(1) << "Adding retval index " << retval_index + << " with non-data-dependent tensor to XLA computation"; + ConstRetVal value; + value.index = retval_index; + TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype, &value.value)); + mutex_lock l(mu_); + compile_time_constant_.push_back(std::move(value)); + return Status::OK(); +} + +/* static */ const XlaExpression* XlaContext::CastExpressionFromTensor( + const Tensor& tensor) { + const XlaExpression* expression = + reinterpret_cast(tensor.tensor_data().data()); + CHECK_NE(expression->handle().handle(), 0); + VLOG(1) << "Fetched T" << expression->handle().handle(); + return expression; +} + +/* static */ XlaExpression* XlaContext::CastExpressionFromUninitializedTensor( + Tensor* tensor) { + const XlaExpression* expression = + reinterpret_cast(tensor->tensor_data().data()); + CHECK_EQ(expression->handle().handle(), 0); + return const_cast(expression); +} + +/* static */ const XlaExpression* XlaContext::GetExpressionFromTensor( + const Tensor& tensor) { + return CastExpressionFromTensor(tensor); +} + +/* static */ const xla::ComputationDataHandle& +XlaContext::GetComputationFromTensor(const Tensor& tensor) { + return CastExpressionFromTensor(tensor)->handle(); +} + +xla::ComputationBuilder& XlaContext::builder() { return xla_builder_; } + +const xla::Computation* XlaContext::GetOrCreateMax(const DataType type) { + return LookupOrCreate(type, &max_func_, [this, type] { + const string type_string = DataTypeString(type); + VLOG(1) << "Building Max() for " << type_string; + xla::ComputationBuilder b(builder().client(), "max<" + type_string + ">"); + xla::PrimitiveType xla_type; + TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type)); + auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); + auto y = b.Parameter(1, xla::ShapeUtil::MakeShape(xla_type, {}), "y"); + b.Max(x, y); + return b.Build().ConsumeValueOrDie(); + }); +} + +const xla::Computation* XlaContext::GetOrCreateAdd(const DataType type) { + return LookupOrCreate(type, &add_func_, [this, type] { + const string type_string = 
DataTypeString(type); + VLOG(1) << "Building Add() for " << type_string; + xla::ComputationBuilder b(builder().client(), "add<" + type_string + ">"); + xla::PrimitiveType xla_type; + TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type)); + auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); + auto y = b.Parameter(1, xla::ShapeUtil::MakeShape(xla_type, {}), "y"); + b.Add(x, y); + return b.Build().ConsumeValueOrDie(); + }); +} + +const xla::Computation* XlaContext::GetOrCreateSigmoid(const DataType type) { + return LookupOrCreate(type, &sigmoid_func_, [this, type] { + const string type_string = DataTypeString(type); + VLOG(1) << "Building Sigmoid() for " << type_string; + xla::ComputationBuilder b(builder().client(), + "sigmoid<" + type_string + ">"); + xla::PrimitiveType xla_type; + TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type)); + auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); + auto one = b.ConstantLiteral(xla::LiteralUtil::One(xla_type)); + auto minus_one = b.Neg(one); + b.Div(one, b.Add(b.Exp(b.Mul(x, minus_one)), one)); + return b.Build().ConsumeValueOrDie(); + }); +} + +const xla::Computation* XlaContext::LookupOrCreate( + DataType type, ComputationMap* out, + const std::function& create) { + { + mutex_lock l(mu_); + const auto& entry = (*out)[type]; + if (!entry.IsNull()) { + return &entry; + } + } + auto new_entry = create(); + { + mutex_lock l(mu_); + // Somebody else might have made one concurrently. + auto& entry = (*out)[type]; + if (entry.IsNull()) { + entry = std::move(new_entry); + } + return &entry; + } +} + +} // end namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h new file mode 100644 index 0000000000..b0464025f7 --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -0,0 +1,277 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines the contexts used to represent XLA JIT computatations. + +#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_CONTEXT_H_ +#define TENSORFLOW_COMPILER_TF2XLA_XLA_CONTEXT_H_ + +#include + +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/xla/client/client.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" + +namespace tensorflow { + +// A XlaExpression wraps an XLA computation. 
Each Tensor sent +// along an edge during XLA JIT compilation represents a +// XlaExpression, and the shape of the Tensor matches the shape of +// the subcomputation in the ComputationDataHandle. Each +// expression is either a constant, an unbound parameter, or a +// function of previously-compiled expressions. +class XlaExpression { + public: + XlaExpression(); + + // handle() stores the XLA handle of the computation that the + // expression represents. + void set_handle(const xla::ComputationDataHandle& h); + const xla::ComputationDataHandle& handle() const; + + void set_constant_value(Tensor value); + bool has_constant_value() const { return has_constant_value_; } + const Tensor& constant_value() const { return constant_value_; } + + private: + friend class XlaContext; + + // The XLA handle of the expression's computation. + xla::ComputationDataHandle handle_; + + // If this expression is a constant with a known value, 'constant_value' is a + // host-memory Tensor containing the value. Used to avoid invoking XLA for + // expressions that are trivially constant. + bool has_constant_value_; + Tensor constant_value_; + + TF_DISALLOW_COPY_AND_ASSIGN(XlaExpression); +}; + +// The XlaContext is the datastructure accessible from +// OpKernelContexts when evaluating a subgraph of Ops for JIT +// compilation by XLA. When an Op is executed during JIT +// compilation the input Tensors to the Op store handles to +// subcomputations compiled by earlier Ops in the subgraph. The Op can +// retrieve these subcomputations by calling either +// GetExpressionFromTensor, which returns the XlaExpression holding +// the subcomputation; or EvaluateAsConstant which returns an XLA +// literal of the result of the subcomputation or an error status if +// the subcomputation depends on unbound parameters. The Op may then +// use the ComputationBuilder available from XlaContext::builder() +// to compile one or more functions of the inputs into +// ComputationDataHandles. The handles can be stored as new +// expressions corresponding to the outputs of the Op by calling +// CreateOutputTensorFromComputation or +// CreateConstantOutputTensor. The *only* correct way to allocate an +// output tensor is using one of the preceding two methods, since they +// ensure there is a valid XlaExpression backing the output +// tensor. No Op should ever call allocate_output or allocate_temp +// directly on the OpKernelContext. It is permissible to pass a tensor +// from an Op input to an output (e.g. call ctx->set_output with a +// tensor passed as an input). As an example, the softmax Op produces +// output from input as follows: +// +// XlaContext& tc = XlaContext::Get(context); +// xla::ComputationBuilder& b = tc.builder(); +// xla::ComputationDataHandle logits = +// tc.GetComputationFromTensor(logits_in)); +// ... The softmax computation uses the builder b to compute a +// xla::ComputationDataHandle softmax holding the desired output. +// ... +// OP_REQUIRES_OK(context, tc.CreateOutputTensorFromComputation( +// context, 0, logits_in.shape().dim_sizes(), +// softmax)); +// +class XlaContext : public ResourceBase { + public: + // If a retval can be evaluated at JIT time it is returned as a + // Literal in a ConstRetVal struct as part of the ComputationResult. + // TODO(misard) reconcile this with the duplicate data structure in + // the XlaCompilationCache class. + struct ConstRetVal { + // The index of the RetVal corresponding to this constant literal. + int index; + // If status is not OK, value's data is undefined. 
+ Status status; + // The value of the RetVal evaluated at JIT compilation + // time. value.shape() always gives the correct shape of the + // RetVal. If !status.ok() then value's data is undefined, otherwise the + // Tensor buffer is allocated in CPU memory. + Tensor value; + }; + + + // Virtual method defined by ResourceBase. + string DebugString() override; + + // Retrieve the XlaContext corresponding to a step's JIT compilation. + static XlaContext& Get(const OpKernelContext* ctx); + static XlaContext& Get(const XlaOpKernelContext* ctx) { + return Get(ctx->op_kernel_context()); + } + + // Create a new XlaContext. + XlaContext(xla::Client* client, const string& computation_name, + bool allow_cpu_custom_calls); + + // Builds XLA computations for each of the arguments. + // Should only be called once to initialize the arguments. Not thread-safe. + Status BuildArguments(std::vector arguments, + bool use_tuple_arg) TF_MUST_USE_RESULT; + + // Returns the results of the symbolic computation that have accumulated in + // the XlaContext. After CollectResults() is called, the context is left in + // an invalid state and must not be reused. + // Sets `requires_runtime_context` if the emitted computation requires a + // runtime context argument. `compile_time_constants` describes any non + // data-dependent results of the computation. `num_nonconst_ouputs` is set to + // the number of outputs of the `computation`. + Status CollectResults(xla::Computation* computation, + bool* requires_runtime_context, + std::vector* compile_time_constants, + int* num_nonconst_outputs); + + // This is called by the Retval Op to associate a computed value + // with a specific return value of the subgraph. + void AddRetval(int retval_index, const xla::ComputationDataHandle& handle); + + // As for Retval, but for return values that are compile-time constants. + Status AddConstRetval(int retval_index, DataType dtype, + const xla::Literal& literal); + + // Retrieves the ComputationDataHandle from an input Tensor to an Op. This + // computation was constructed by an Op that executed previously and + // created the output Tensor using CreateOutputTensorFromComputation + // or CreateConstantOutputTensor. + static const xla::ComputationDataHandle& GetComputationFromTensor( + const Tensor& tensor); + + // Returns the ComputationBuilder that Ops use for compiling new + // expressions. + xla::ComputationBuilder& builder(); + + const std::vector& args() const { return args_; } + xla::ComputationDataHandle parameter(int num) { return parameters_[num]; } + + // Get the runtime context parameter, adding one if it does not already exist. + // Dies if not compiling a local executable. + const xla::ComputationDataHandle& GetOrCreateRuntimeContextParameter(); + + bool allow_cpu_custom_calls() const { return allow_cpu_custom_calls_; } + + // Get an XLA lambda to compute Max. This is cached in the + // XlaContext since it may be used by multiple Ops. There is a + // separate specialization of the computation for each DataType. + const xla::Computation* GetOrCreateMax(const DataType type); + + // Get an XLA lambda to compute Add. This is cached in the + // XlaContext since it may be used by multiple Ops. There is a + // separate specialization of the computation for each DataType. + const xla::Computation* GetOrCreateAdd(const DataType type); + + // Get an XLA lambda to compute Sigmoid. This is cached in the + // XlaContext since it may be used by multiple Ops. There is a + // separate specialization of the computation for each DataType. 
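+  //
+  // For illustration, a reduction kernel might combine one of these cached
+  // computations with the builder like so (`ctx`, `input`, `dtype` and
+  // `dims_to_reduce` are placeholders):
+  //
+  //   xla::ComputationBuilder* b = ctx->builder();
+  //   auto sum = b->Reduce(input, XlaHelpers::Zero(b, dtype),
+  //                        *XlaContext::Get(ctx).GetOrCreateAdd(dtype),
+  //                        dims_to_reduce);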
+ const xla::Computation* GetOrCreateSigmoid(const DataType type); + + // The name of the XlaContext resource during symbolic graph execution. + static const char kXlaContextResourceName[]; + + private: + friend class XlaOpKernelContext; + + // This method is used to retrieve an expression that was allocated by + // a previous Op. + static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor); + + // This method is used to retrieve an uninitialized expression from a + // newly-allocated tensor. + static XlaExpression* CastExpressionFromUninitializedTensor(Tensor* tensor); + + // Retrieves the expression from an input Tensor to an Op. This + // expression was constructed by an Op that executed previously and + // created the output Tensor using CreateOutputTensorFromComputation + // or CreateConstantOutputTensor. + static const XlaExpression* GetExpressionFromTensor(const Tensor& tensor); + + mutable mutex mu_; + + // The ComputationBuilder used to construct the subgraph's compiled + // representation. + xla::ComputationBuilder xla_builder_ GUARDED_BY(mu_); + + // Number of XLA Parameters, not counting the context parameter, if any. + int num_parameters_; + + // Arguments to the JIT compilation, both compile-time constant arguments and + // runtime parameters. + std::vector args_; + bool use_tuple_arg_ = false; + + // Runtime parameters to the XLA computation. Does not include + // compile-time constant arguments. + std::vector parameters_; + + // Allow ops to emit CustomCall operations for CPU. + const bool allow_cpu_custom_calls_; + + // When 'has_context_parameter_' is true, this is the computation handle + // for an additional final parameter to the computation, through which will be + // passed a XlaLocalRuntimeContext* at runtime. Created on demand by + // GetOrCreateRuntimeContextParameter(). + bool has_context_parameter_ GUARDED_BY(mu_) = false; + xla::ComputationDataHandle context_parameter_ GUARDED_BY(mu_); + + // The data-dependent return values of the computation. + std::vector> retval_ + GUARDED_BY(mu_); + + // The non-data-dependent return values of the computation. + std::vector compile_time_constant_ GUARDED_BY(mu_); + + // Cache of prebuilt computations indexed by their type. + using ComputationMap = std::map; + + // Finds the value for the given type in out map if it already + // exists or makes a new value with create function and keeps it the + // map. The returned value != nullptr and is owned by the map. + const xla::Computation* LookupOrCreate( + DataType type, ComputationMap* out, + const std::function& create) LOCKS_EXCLUDED(mu_); + + // Cached computation to compute Max of two elements, specialized by type. + ComputationMap max_func_ GUARDED_BY(mu_); + + // Cached computation to compute Sum of two elements, specialized by type. + ComputationMap add_func_ GUARDED_BY(mu_); + + // Cached computation to compute Sigmoid of an element, specialized by type. + ComputationMap sigmoid_func_ GUARDED_BY(mu_); + + TF_DISALLOW_COPY_AND_ASSIGN(XlaContext); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_CONTEXT_H_ diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc new file mode 100644 index 0000000000..efb0facf7b --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -0,0 +1,142 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This file defines helper routines for XLA JIT compilation.
+
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/literal_util.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_context.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+
+namespace tensorflow {
+
+xla::ComputationDataHandle XlaHelpers::MinValue(xla::ComputationBuilder* b,
+                                                DataType data_type) {
+  xla::PrimitiveType type;
+  TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
+  return b->ConstantLiteral(xla::LiteralUtil::MinValue(type));
+}
+
+xla::ComputationDataHandle XlaHelpers::MaxValue(xla::ComputationBuilder* b,
+                                                DataType data_type) {
+  xla::PrimitiveType type;
+  TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
+  return b->ConstantLiteral(xla::LiteralUtil::MaxValue(type));
+}
+
+xla::ComputationDataHandle XlaHelpers::Zero(xla::ComputationBuilder* b,
+                                            DataType data_type) {
+  xla::PrimitiveType type;
+  TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
+  return b->ConstantLiteral(xla::LiteralUtil::Zero(type));
+}
+
+xla::ComputationDataHandle XlaHelpers::One(xla::ComputationBuilder* b,
+                                           DataType data_type) {
+  xla::PrimitiveType type;
+  TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
+  return b->ConstantLiteral(xla::LiteralUtil::One(type));
+}
+
+xla::ComputationDataHandle XlaHelpers::IntegerLiteral(
+    xla::ComputationBuilder* b, DataType data_type, int64 value) {
+  xla::Literal literal;
+  xla::PrimitiveType type;
+  TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
+  switch (type) {
+    case xla::U8:
+      literal = *xla::LiteralUtil::CreateR0<uint8>(value);
+      break;
+    case xla::U32:
+      literal = *xla::LiteralUtil::CreateR0<uint32>(value);
+      break;
+    case xla::U64:
+      literal = *xla::LiteralUtil::CreateR0<uint64>(value);
+      break;
+    case xla::S8:
+      literal = *xla::LiteralUtil::CreateR0<int8>(value);
+      break;
+    case xla::S32:
+      literal = *xla::LiteralUtil::CreateR0<int32>(value);
+      break;
+    case xla::S64:
+      literal = *xla::LiteralUtil::CreateR0<int64>(value);
+      break;
+    case xla::F32:
+      literal = *xla::LiteralUtil::CreateR0<float>(value);
+      break;
+    case xla::F64:
+      literal = *xla::LiteralUtil::CreateR0<double>(value);
+      break;
+    case xla::PRED:
+      LOG(FATAL) << "pred element type is not integral";
+    case xla::S16:
+    case xla::U16:
+      LOG(FATAL) << "u16/s16 literals not yet implemented";
+    case xla::F16:
+      LOG(FATAL) << "f16 literals not yet implemented";
+    case xla::TUPLE:
+      LOG(FATAL) << "tuple element type is not integral";
+    case xla::OPAQUE:
+      LOG(FATAL) << "opaque element type is not integral";
+    default:
+      LOG(FATAL) << "unhandled element type " << type;
+  }
+  return b->ConstantLiteral(literal);
+}
+
+xla::ComputationDataHandle XlaHelpers::FloatLiteral(xla::ComputationBuilder* b,
+                                                    DataType data_type,
+                                                    double value) {
+  xla::Literal literal;
+  xla::PrimitiveType type;
+  TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
+  switch (type) {
+    case xla::F32:
+      return
b->ConstantR0(static_cast(value)); + break; + case xla::F64: + return b->ConstantR0(value); + break; + default: + LOG(FATAL) << "unhandled element type " << type; + } +} + +/* static */ Status XlaHelpers::ReshapeLiteral( + const xla::Literal& input, gtl::ArraySlice dimensions, + xla::Literal* output) { + if (xla::ShapeUtil::IsTuple(input.shape())) { + return errors::InvalidArgument("ReshapeLiteral does not support tuples."); + } + xla::Shape shape = + xla::ShapeUtil::MakeShape(input.shape().element_type(), dimensions); + int64 elements_before = xla::ShapeUtil::ElementsIn(input.shape()); + int64 elements_after = xla::ShapeUtil::ElementsIn(shape); + if (elements_before != elements_after) { + return errors::InvalidArgument( + "Shapes before and after ReshapeLiteral have different numbers of " + "elements."); + } + + *output = input; + output->mutable_shape()->Swap(&shape); + return Status::OK(); +} + +} // end namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h new file mode 100644 index 0000000000..353ed02edd --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_helpers.h @@ -0,0 +1,73 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines helper routines for the TLA device. + +#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_HELPERS_H_ +#define TENSORFLOW_COMPILER_TF2XLA_XLA_HELPERS_H_ + +#include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace tensorflow { + +// Helper methods for building XLA computations. +class XlaHelpers { + public: + // Returns a handle representing the minimum value of a scalar + // element of data_type. + static xla::ComputationDataHandle MinValue(xla::ComputationBuilder* b, + DataType data_type); + + // Returns a handle representing the maximum value of a scalar + // element of data_type. + static xla::ComputationDataHandle MaxValue(xla::ComputationBuilder* b, + DataType data_type); + + // Returns a handle representing the zero value of a scalar + // element of data_type. + static xla::ComputationDataHandle Zero(xla::ComputationBuilder* b, + DataType data_type); + + // Returns a handle representing the one value of a scalar + // element of data_type. + static xla::ComputationDataHandle One(xla::ComputationBuilder* b, + DataType data_type); + + // Returns a handle representing the given value of an integer scalar + // element of data_type. + // Note that unlike One and Zero, does not work on boolean types. + static xla::ComputationDataHandle IntegerLiteral(xla::ComputationBuilder* b, + DataType data_type, + int64 value); + + // Returns a handle representing the given value of a floating-point scalar + // element of data_type. 
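+  //
+  // For example (illustrative; `b` is an existing ComputationBuilder*):
+  //   auto half = XlaHelpers::FloatLiteral(b, DT_FLOAT, 0.5);
+  //   auto four = XlaHelpers::IntegerLiteral(b, DT_INT32, 4);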
+ static xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* b, + DataType data_type, + double value); + + // Reshapes literal 'input' to have 'shape'. Both the original shape and + // 'shape' must contain the same number of elements. + static Status ReshapeLiteral(const xla::Literal& input, + gtl::ArraySlice shape, + xla::Literal* output); +}; + +} // end namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_HELPERS_H_ diff --git a/tensorflow/compiler/tf2xla/xla_local_runtime_context.h b/tensorflow/compiler/tf2xla/xla_local_runtime_context.h new file mode 100644 index 0000000000..cd773d64ed --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_local_runtime_context.h @@ -0,0 +1,55 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_LOCAL_RUNTIME_CONTEXT_H_ +#define TENSORFLOW_COMPILER_TF2XLA_XLA_LOCAL_RUNTIME_CONTEXT_H_ + +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +// Forward-declare the ThreadPoolDevice so that it can be ignored unless it's +// actually used. E.g. some ahead-of-time compiled computations don't need a +// thread pool. +namespace Eigen { +class ThreadPoolDevice; +} + +namespace tensorflow { + +// An instance of this class is passed to each call from tensorflow into a +// compiled XLA computation. See xla_launch_ops.cc. +struct XlaLocalRuntimeContext { + public: + XlaLocalRuntimeContext() {} + + // Kernels implemented using custom call ops set this if they encounter an + // error. The error is checked after the entire XLA computation is + // complete. + // + // error+error_msg are used instead of Status to reduce the binary size + // overhead for ahead-of-time compiled binaries. + bool error = false; + string error_msg; + + // Kernels that need a thread pool can get it from here. + const Eigen::ThreadPoolDevice* thread_pool = nullptr; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalRuntimeContext); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_LOCAL_RUNTIME_CONTEXT_H_ diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc new file mode 100644 index 0000000000..3883b907b4 --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -0,0 +1,253 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" + +#include + +#include "tensorflow/compiler/tf2xla/literal_util.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_context.h" + +namespace tensorflow { + +XlaOpKernelContext::XlaOpKernelContext(OpKernelContext* context) + : context_(context) {} + +bool XlaOpKernelContext::ValidateInputsAreSameShape(OpKernel* op) { + return context_->ValidateInputsAreSameShape(op); +} + +xla::ComputationBuilder* XlaOpKernelContext::builder() const { + return &XlaContext::Get(this).builder(); +} + +const xla::ComputationDataHandle& XlaOpKernelContext::Input(int index) { + return XlaContext::GetComputationFromTensor(context_->input(index)); +} + +TensorShape XlaOpKernelContext::InputShape(int index) { + return context_->input(index).shape(); +} + +Status XlaOpKernelContext::ConstantInput(int index, + xla::Literal* constant_literal) { + return ConstantInputReshaped( + index, context_->input(index).shape().dim_sizes(), constant_literal); +} + +Status XlaOpKernelContext::ConstantInputReshaped( + int index, gtl::ArraySlice new_dims, + xla::Literal* constant_literal) { + const Tensor& tensor = context_->input(index); + TensorShape new_shape(new_dims); + if (tensor.NumElements() != new_shape.num_elements()) { + return errors::InvalidArgument( + context_->op_kernel().name(), " input ", index, " has shape ", + tensor.shape().DebugString(), + " but was asked to be reshaped to incompatible shape ", + new_shape.DebugString()); + } + const XlaExpression* expression = + XlaContext::CastExpressionFromTensor(tensor); + + // If the tensor has a known constant value, there is no need to invoke XLA. + if (expression->has_constant_value()) { + Tensor temp(tensor.dtype()); + if (!temp.CopyFrom(expression->constant_value(), new_shape)) { + // This should never happen. The constant should have a shape compatible + // with the enclosing Tensor. + return errors::Internal("Incompatible shapes in ConstantInputReshaped."); + } + return HostTensorToLiteral(temp, constant_literal); + } + + // Make sure we treat zero-element tensors as constant. + if (new_shape.num_elements() == 0) { + Tensor temp(tensor.dtype(), new_shape); + return HostTensorToLiteral(temp, constant_literal); + } + + xla::ComputationDataHandle handle = expression->handle(); + if (new_shape != tensor.shape()) { + // Reshape the handle to the desired shape. + handle = builder()->Reshape(handle, new_shape.dim_sizes()); + } + + // The XLA layout is specified minor to major, and TensorFlow's minor + // dimension is the last one. + std::vector layout_indices(new_shape.dims()); + std::iota(layout_indices.rbegin(), layout_indices.rend(), 0); + xla::Layout layout = xla::LayoutUtil::MakeLayout(layout_indices); + + // Ask the XLA compiler to evaluate the data handle to a literal. + xla::StatusOr> computed = + builder()->ComputeConstant(handle, &layout); + if (!computed.ok()) { + return errors::InvalidArgument( + "Error evaluating ", context_->op_kernel().name(), " input ", index, + ": ", computed.status().error_message()); + } + // Fetch the literal from the compiler service. 
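+  // (The handle returned by ComputeConstant refers to data held by the XLA
+  // service; Transfer below copies that data back to the host as a Literal.)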
+ xla::StatusOr> constant = + builder()->client()->Transfer(*computed.ValueOrDie()); + if (!constant.ok()) { + return errors::InvalidArgument( + "Error evaluating ", context_->op_kernel().name(), " input ", index, + ": ", constant.status().error_message()); + } + constant_literal->Swap(constant.ValueOrDie().get()); + return Status::OK(); +} + +// Converts an int32 or int64 1D literal to an int64 vector. +static Status LiteralToInt64Vector(const xla::Literal& literal, + std::vector* out) { + if (xla::ShapeUtil::Rank(literal.shape()) != 1) { + return errors::InvalidArgument("value is not 1D"); + } + int64 size = xla::ShapeUtil::ElementsIn(literal.shape()); + if (literal.shape().element_type() == xla::S32) { + for (int64 i = 0; i < size; ++i) { + out->push_back(xla::LiteralUtil::Get(literal, {i})); + } + } else if (literal.shape().element_type() == xla::S64) { + for (int64 i = 0; i < size; ++i) { + out->push_back(xla::LiteralUtil::Get(literal, {i})); + } + } else { + return errors::InvalidArgument("value must be either int32 or int64"); + } + return Status::OK(); +} + +Status XlaOpKernelContext::ConstantInputAsIntVector(int index, + std::vector* out) { + xla::Literal literal; + TF_RETURN_IF_ERROR(ConstantInput(index, &literal)); + return LiteralToInt64Vector(literal, out); +} + +// TODO(phawkins): validate that the dimensions form a valid shape, fail +// gracefully if they do not. +Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape) { + xla::Literal literal; + TF_RETURN_IF_ERROR(ConstantInput(index, &literal)); + std::vector dims; + TF_RETURN_IF_ERROR(LiteralToInt64Vector(literal, &dims)); + *shape = TensorShape(dims); + return Status::OK(); +} + +Status XlaOpKernelContext::InputList( + StringPiece name, std::vector* handles, + std::vector* shapes) { + OpInputList inputs; + TF_RETURN_IF_ERROR(context_->input_list(name, &inputs)); + handles->clear(); + shapes->clear(); + for (const Tensor& input : inputs) { + handles->push_back(XlaContext::GetComputationFromTensor(input)); + shapes->push_back(input.shape()); + } + return Status::OK(); +} + +Status XlaOpKernelContext::ConstantInputList( + StringPiece name, std::vector* outputs) { + int start, stop; + TF_RETURN_IF_ERROR(op_kernel().InputRange(name, &start, &stop)); + outputs->resize(stop - start); + for (int i = start; i < stop; ++i) { + TF_RETURN_IF_ERROR(ConstantInput(i, &(*outputs)[i])); + } + return Status::OK(); +} + +void XlaOpKernelContext::SetOutput(int index, + const xla::ComputationDataHandle& handle) { + // Makes the host Tensor that will refer to the expression. + Tensor* output = nullptr; + auto shape = builder()->GetShape(handle); + if (!shape.ok()) { + SetStatus(shape.status()); + return; + } + + // The step's default allocator is the dummy XlaCompilationAllocator which + // simply allocates a metadata buffer to hold the expression to which it + // corresponds. + OP_REQUIRES_OK( + context_, + context_->allocate_output( + index, XLAShapeToTensorShape(*shape.ValueOrDie()), &output)); + + // The expression is stored in the tensor's data buffer. Fill in the + // fields now. 
+ XlaExpression* expression = + XlaContext::CastExpressionFromUninitializedTensor(output); + expression->set_handle(handle); +} + +void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) { + const TensorShape& shape = constant.shape(); + + xla::Literal literal; + OP_REQUIRES_OK(context_, HostTensorToLiteral(constant, &literal)); + xla::ComputationDataHandle handle = builder()->ConstantLiteral(literal); + + // Make the Tensor that will refer to the expression. + Tensor* output = nullptr; + // The step's default allocator is the dummy XlaCompilationAllocator which + // simply allocates a metadata buffer to hold the expression to which it + // corresponds. + OP_REQUIRES_OK(context_, context_->allocate_output(index, shape, &output)); + + // The expression is stored in the tensor's data buffer. Fill in the + // fields now. + XlaExpression* expression = + XlaContext::CastExpressionFromUninitializedTensor(output); + expression->set_handle(handle); + expression->set_constant_value(constant); +} + +void XlaOpKernelContext::CtxFailure(Status s) { context_->CtxFailure(s); } +void XlaOpKernelContext::CtxFailureWithWarning(Status s) { + context_->CtxFailureWithWarning(s); +} + +const xla::Computation* XlaOpKernelContext::GetOrCreateMax( + const DataType type) { + return XlaContext::Get(context_).GetOrCreateMax(type); +} + +const xla::Computation* XlaOpKernelContext::GetOrCreateAdd( + const DataType type) { + return XlaContext::Get(context_).GetOrCreateAdd(type); +} + +const xla::Computation* XlaOpKernelContext::GetOrCreateSigmoid( + const DataType type) { + return XlaContext::Get(context_).GetOrCreateSigmoid(type); +} + +XlaOpKernel::XlaOpKernel(OpKernelConstruction* context) : OpKernel(context) {} + +void XlaOpKernel::Compute(OpKernelContext* context) { + XlaOpKernelContext xla_context(context); + Compile(&xla_context); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h new file mode 100644 index 0000000000..0c614005be --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -0,0 +1,174 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_ +#define TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_ + +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { + +class XlaOpKernelContext; + +// Implementations of operators that generate XLA code should usually subclass +// XlaOpKernel and implement the Compile() method. Unlike a regular OpKernel, +// an XlaOpKernel produces and consumes symbolic values during compilation. +// +// See the comments in xla_context.h for more details. 
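// As a minimal sketch (illustrative only; the "Neg" builder call and the
// kernel below are assumptions for exposition, not part of this header), a
// subclass typically reads its inputs, emits XLA operations via builder(),
// and sets its outputs:
//
//   class HypotheticalNegOp : public XlaOpKernel {
//    public:
//     explicit HypotheticalNegOp(OpKernelConstruction* ctx)
//         : XlaOpKernel(ctx) {}
//     void Compile(XlaOpKernelContext* ctx) override {
//       ctx->SetOutput(0, ctx->builder()->Neg(ctx->Input(0)));
//     }
//   };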
+class XlaOpKernel : public OpKernel { + public: + explicit XlaOpKernel(OpKernelConstruction* construction); + + // Subclasses should implement Compile(), much as standard OpKernels implement + // Compute(). + virtual void Compile(XlaOpKernelContext* context) = 0; + + private: + void Compute(OpKernelContext* context) final; +}; + +// The context passed to the Compile() method of XlaOpKernel. An +// XlaOpKernelContext is a variant of the standard OpKernel class, tailored for +// implementing operators that perform symbolic execution as part of the XLA +// compiler. The key difference is that XlaOpKernelContext produces and consumes +// data as XLA computations, rather than as standard Tensors. (Under the hood, +// symbolic execution communicates using special Tensors, but that is an +// implementation detail that this class hides.) +class XlaOpKernelContext { + public: + explicit XlaOpKernelContext(OpKernelContext* context); + + // Returns the XLA ComputationBuilder containing the output of compilation. + xla::ComputationBuilder* builder() const; + + // Inputs + + // Returns the number of inputs to the operator. + int num_inputs() const { return context_->num_inputs(); } + + // Returns the type of input 'index'. + DataType input_type(int index) { return context_->input(index).dtype(); } + + // Returns the shape of input 'index'. + TensorShape InputShape(int index); + + // Returns input 'index' as a ComputationDataHandle. Unlike + // OpKernelContext::Input returns a symbolic value rather than a concrete + // Tensor. + const xla::ComputationDataHandle& Input(int index); + + // Returns true if all inputs are the same shape, otherwise sets the + // status to a non-OK value and returns false. + // Usage: if (!context->ValidateInputsAreSameShape(this)) return; + bool ValidateInputsAreSameShape(OpKernel* op) TF_MUST_USE_RESULT; + + // Returns the named list-valued immutable input in "list", as + // defined in the OpDef. If the named output is not list-valued, + // returns a one-element list. + Status InputList(StringPiece name, + std::vector* handles, + std::vector* shapes); + + // Helper methods for constant inputs. + + // Evaluates input 'index' and stores it in '*constant_literal'. If the + // expression cannot be evaluated, e.g., because it depends on unbound + // parameters, returns a non-OK status. + Status ConstantInput(int index, xla::Literal* constant_literal); + + // Evaluates input 'index', reshapes it to 'new_shape' if new_shape != + // InputShape(index), and stores it in '*constant_literal'. If the input + // cannot be evaluated, e.g., because it depends on unbound parameters, + // returns a non-Ok status. If InputShape(index).num_elements() != + // new_shape.num_elements(), returns an error status. + Status ConstantInputReshaped(int index, gtl::ArraySlice new_shape, + xla::Literal* constant_literal); + + // Converts a constant 1D int32 or int64 tensor into a vector of int64s. + Status ConstantInputAsIntVector(int index, std::vector* out); + + // Converts a constant 1D int32 or int64 tensor into a TensorShape. + Status ConstantInputAsShape(int index, TensorShape* shape); + + // Returns the named list-valued immutable input in "list", as + // defined in the OpDef. If the named output is not list-valued, + // returns a one-element list. 
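  // (That is, like InputList above, but each input in the list is evaluated
  // to a constant Literal, as with ConstantInput.)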
+ Status ConstantInputList(StringPiece name, + std::vector* literals); + + // Outputs + + int num_outputs() const { return context_->num_outputs(); } + DataType expected_output_dtype(int index) const { + return context_->expected_output_dtype(index); + } + + // Sets output 'index' to the ComputationDataHandle 'handle'. + // All outputs should be set using SetOutput and SetConstantOutput, not + // via the underlying OpKernelContext. + void SetOutput(int index, const xla::ComputationDataHandle& handle); + + // Sets output 'index' to compile-time constant 'host_tensor', where + // 'host_tensor' is a tensor in host memory. It is preferable to use + // SetConstantOutput where possible. + void SetConstantOutput(int index, const Tensor& host_tensor); + + // Status handling. + void SetStatus(const Status& status) { context_->SetStatus(status); } + Status status() { return context_->status(); } + + // Helper routines for the OP_REQUIRES macros + void CtxFailure(Status s); + void CtxFailureWithWarning(Status s); + + // If this kernel invocation is within a function execution, + // call_frame() returns the call frame for the function call. + FunctionCallFrame* call_frame() const { return context_->call_frame(); } + + FunctionLibraryRuntime* function_library() const { + return context_->function_library(); + } + + const OpKernel& op_kernel() const { return context_->op_kernel(); } + + // Returns the underlying OpKernelContext. Use rarely. + OpKernelContext* op_kernel_context() const { return context_; } + + // TODO(phawkins): find a better home for these helpers. + + // Get an XLA lambda to compute Max. This is cached in the + // XlaContext since it may be used by multiple Ops. There is a + // separate specialization of the computation for each DataType. + const xla::Computation* GetOrCreateMax(const DataType type); + + // Get an XLA lambda to compute Add. This is cached in the + // XlaContext since it may be used by multiple Ops. There is a + // separate specialization of the computation for each DataType. + const xla::Computation* GetOrCreateAdd(const DataType type); + + // Get an XLA lambda to compute Sigmoid. This is cached in the + // XlaContext since it may be used by multiple Ops. There is a + // separate specialization of the computation for each DataType. + const xla::Computation* GetOrCreateSigmoid(const DataType type); + + private: + OpKernelContext* const context_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_ diff --git a/tensorflow/compiler/xla/.clang-format b/tensorflow/compiler/xla/.clang-format new file mode 100644 index 0000000000..c2aa867556 --- /dev/null +++ b/tensorflow/compiler/xla/.clang-format @@ -0,0 +1,3 @@ +BasedOnStyle: Google +Language: Cpp +PointerBindsToType: true diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD new file mode 100644 index 0000000000..caa5141da4 --- /dev/null +++ b/tensorflow/compiler/xla/BUILD @@ -0,0 +1,561 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +package_group( + name = "friends", + packages = [ + "//tensorflow/compiler/...", + ], +) + +package_group( + name = "internal", + packages = [ + "//tensorflow/compiler/xla/...", + ], +) + +load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") + +# Filegroup used to collect source files for dependency checking. 
+filegroup( + name = "c_srcs", + data = glob([ + "**/*.cc", + "**/*.h", + ]), +) + +xla_proto_library( + name = "xla_data_proto", + srcs = ["xla_data.proto"], + visibility = ["//visibility:public"], +) + +xla_proto_library( + name = "xla_proto", + srcs = ["xla.proto"], + visibility = ["//visibility:public"], + deps = [ + ":xla_data_proto", + "//tensorflow/compiler/xla/service:session_proto", + ], +) + +cc_library( + name = "types", + hdrs = ["types.h"], + visibility = [":friends"], + deps = ["//tensorflow/core:lib"], +) + +cc_library( + name = "service_interface", + srcs = [], + hdrs = ["service_interface.h"], + visibility = [":friends"], + deps = [ + ":xla_proto", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "status_macros", + srcs = ["status_macros.cc"], + hdrs = ["status_macros.h"], + visibility = [":friends"], + deps = [ + ":statusor", + ":types", + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "status_macros_test", + size = "small", + srcs = ["status_macros_test.cc"], + deps = [ + ":status_macros", + ":statusor", + ":test_helpers", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "status", + hdrs = ["status.h"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + +cc_library( + name = "statusor", + srcs = ["statusor.cc"], + hdrs = ["statusor.h"], + visibility = ["//visibility:public"], + deps = [ + ":status", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + +cc_test( + name = "statusor_test", + size = "small", + srcs = ["statusor_test.cc"], + deps = [ + ":statusor", + ":types", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "util", + srcs = ["util.cc"], + hdrs = [ + "map_util.h", + "ptr_util.h", + "util.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":status", + ":types", + ":xla_data_proto", + "//tensorflow/compiler/xla/legacy_flags:util_flags", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "protobuf_util", + srcs = ["protobuf_util.cc"], + hdrs = [ + "protobuf_util.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":types", + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "util_test", + srcs = ["util_test.cc"], + deps = [ + ":types", + ":util", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "shape_util", + srcs = [ + "index_util.cc", + "layout_util.cc", + "primitive_util.cc", + "shape_util.cc", + ], + hdrs = [ + "index_util.h", + "layout_util.h", + "primitive_util.h", + "shape_util.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":protobuf_util", + ":status_macros", + ":statusor", + ":types", + ":util", + ":xla_data_proto", + "//tensorflow/compiler/xla/legacy_flags:layout_util_flags", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:regexp_internal", + ], +) + +cc_test( + name = "shape_util_test", + srcs = ["shape_util_test.cc"], + deps = [ + ":shape_util", + ":test_helpers", + ":types", + ":util", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_test( + name = "layout_util_test", + srcs = ["layout_util_test.cc"], + deps = [ + ":shape_util", + ":test_helpers", + "//tensorflow/compiler/xla/legacy_flags:layout_util_flags", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_test( + name = "index_util_test", + srcs = 
["index_util_test.cc"], + deps = [ + ":shape_util", + ":test_helpers", + ":xla_data_proto", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "literal_util", + srcs = ["literal_util.cc"], + hdrs = ["literal_util.h"], + visibility = ["//visibility:public"], + deps = [ + ":array2d", + ":array3d", + ":array4d", + ":shape_util", + ":types", + ":util", + ":xla_data_proto", + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "literal_util_test", + srcs = ["literal_util_test.cc"], + deps = [ + ":array3d", + ":array4d", + ":literal_util", + ":shape_util", + ":test_helpers", + ":types", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "device_util", + hdrs = ["device_util.h"], + visibility = ["//visibility:public"], + deps = [ + ":types", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + ], +) + +cc_library( + name = "array2d", + srcs = ["array2d.cc"], + hdrs = ["array2d.h"], + visibility = ["//visibility:public"], + deps = [ + ":types", + ":util", + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "array2d_test", + srcs = ["array2d_test.cc"], + deps = [ + ":array2d", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "array3d", + hdrs = ["array3d.h"], + visibility = [":friends"], + deps = [ + ":types", + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "array3d_test", + srcs = ["array3d_test.cc"], + deps = [ + ":array3d", + ":types", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "array4d", + hdrs = ["array4d.h"], + visibility = [":friends"], + deps = [ + ":array2d", + ":types", + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "array4d_test", + srcs = ["array4d_test.cc"], + deps = [ + ":array4d", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "executable_run_options", + srcs = ["executable_run_options.cc"], + hdrs = ["executable_run_options.h"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "differential_set", + hdrs = ["differential_set.h"], + visibility = [":internal"], + deps = [ + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "differential_set_test", + srcs = ["differential_set_test.cc"], + deps = [ + ":differential_set", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "packed_literal_reader", + srcs = ["packed_literal_reader.cc"], + hdrs = ["packed_literal_reader.h"], + visibility = [":internal"], + deps = [ + ":literal_util", + ":shape_util", + ":status_macros", + ":statusor", + ":types", + ":util", + ":xla_data_proto", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "test_helpers", + testonly = 1, + srcs = ["test_helpers.cc"], + hdrs = ["test_helpers.h"], + visibility = [":internal"], + deps = [ + ":statusor", + ":types", + "//tensorflow/core:lib", + "//tensorflow/core:regexp_internal", + "//tensorflow/core:test", + ], +) + +cc_library( + name = "text_literal_reader", + srcs = ["text_literal_reader.cc"], + hdrs = ["text_literal_reader.h"], + visibility = [":internal"], + deps = [ + ":literal_util", + ":shape_util", + ":status_macros", + ":statusor", + ":types", + ":util", + ":xla_data_proto", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + +cc_test( + name = "text_literal_reader_test", + srcs = ["text_literal_reader_test.cc"], + deps = [ + ":literal_util", + 
":shape_util", + ":text_literal_reader", + ":types", + ":xla_data_proto", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "text_literal_writer", + srcs = ["text_literal_writer.cc"], + hdrs = ["text_literal_writer.h"], + visibility = [":internal"], + deps = [ + ":literal_util", + ":shape_util", + ":status_macros", + ":types", + ":xla_data_proto", + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "text_literal_writer_test", + srcs = ["text_literal_writer_test.cc"], + deps = [ + ":literal_util", + ":test_helpers", + ":text_literal_writer", + ":types", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "shape_tree", + hdrs = ["shape_tree.h"], + visibility = ["//visibility:public"], + deps = [ + ":shape_util", + ":status_macros", + ":util", + ":xla_data_proto", + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "shape_tree_test", + srcs = ["shape_tree_test.cc"], + deps = [ + ":shape_tree", + ":shape_util", + ":xla_data_proto", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "shape_layout", + srcs = ["shape_layout.cc"], + hdrs = ["shape_layout.h"], + visibility = ["//visibility:public"], + deps = [ + ":shape_util", + ":types", + ":util", + ":xla_data_proto", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "window_util", + srcs = ["window_util.cc"], + hdrs = ["window_util.h"], + visibility = ["//visibility:public"], + deps = [ + ":types", + ":xla_data_proto", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "reference_util", + srcs = ["reference_util.cc"], + hdrs = ["reference_util.h"], + visibility = ["//visibility:public"], + deps = [ + ":array2d", + ":array3d", + ":array4d", + ":util", + ":window_util", + ":xla_data_proto", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:padding", + "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul", + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "reference_util_test", + srcs = ["reference_util_test.cc"], + deps = [ + ":array2d", + ":array4d", + ":literal_util", + ":reference_util", + ":util", + ":xla_data_proto", + "//tensorflow/compiler/xla/client:padding", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +# ----------------------------------------------------------------------------- + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/compiler/xla/README.md b/tensorflow/compiler/xla/README.md new file mode 100644 index 0000000000..c93c39e180 --- /dev/null +++ b/tensorflow/compiler/xla/README.md @@ -0,0 +1 @@ +This is the home of XLA. diff --git a/tensorflow/compiler/xla/array2d.cc b/tensorflow/compiler/xla/array2d.cc new file mode 100644 index 0000000000..418587c1f7 --- /dev/null +++ b/tensorflow/compiler/xla/array2d.cc @@ -0,0 +1,36 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/ptr_util.h" + +namespace xla { + +std::unique_ptr> MakeLinspaceArray2D(float from, float to, + int64 n1, int64 n2) { + auto array = MakeUnique>(n1, n2); + int64 count = n1 * n2; + float step = (count > 1) ? (to - from) / (count - 1) : 0.0f; + auto set = [&array, n1, n2](int64 index, float value) { + (*array)(index / n2, index % n2) = value; + }; + for (int64 i = 0; i < count - 1; ++i) { + set(i, from + i * step); + } + set(count - 1, to); + return array; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/array2d.h b/tensorflow/compiler/xla/array2d.h new file mode 100644 index 0000000000..ceed573f1f --- /dev/null +++ b/tensorflow/compiler/xla/array2d.h @@ -0,0 +1,165 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_ARRAY2D_H_ +#define TENSORFLOW_COMPILER_XLA_ARRAY2D_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/bits.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// Simple 2D array structure. +// +// The data layout in major-to-minor order is: n1, n2. +template +class Array2D { + public: + // Creates an empty array. + Array2D() : n1_(0), n2_(0) {} + + // Creates an array of dimensions n1 x n2, uninitialized values. + Array2D(const int64 n1, const int64 n2) : n1_(n1), n2_(n2) { + values_.resize(n1 * n2); + } + + // Creates an array of dimensions n1 x n2, initialized to value. + Array2D(const int64 n1, const int64 n2, const T value) : Array2D(n1, n2) { + Fill(value); + } + + // Creates an array from the given nested initializer list. The outer + // initializer list is the first dimension; the inner is the second dimension. + // For example, {{1, 2, 3}, {4, 5, 6}} results in an array with n1=2 and n2=3. 
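  // The array is sized from the first inner list; every row is expected to
  // have that same length (this is not checked).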
+ Array2D(std::initializer_list> values) + : Array2D(values.size(), values.begin()->size()) { + int64 n1 = 0; + for (auto n1_it = values.begin(); n1_it != values.end(); ++n1_it, ++n1) { + int64 n2 = 0; + for (auto n2_it = n1_it->begin(); n2_it != n1_it->end(); ++n2_it, ++n2) { + (*this)(n1, n2) = *n2_it; + } + } + } + + T& operator()(const int64 n1, const int64 n2) { + CHECK_LT(n1, n1_); + CHECK_LT(n2, n2_); + return values_[n1 * n2_ + n2]; + } + + const T& operator()(const int64 n1, const int64 n2) const { + CHECK_LT(n1, n1_); + CHECK_LT(n2, n2_); + return values_[n1 * n2_ + n2]; + } + + // Access to the array's dimensions. height() and width() provide the + // canonical interpretation of the array n1 x n2 having n1 rows of n2 columns + // each (height is number of rows; width is number of columns). + int64 n1() const { return n1_; } + int64 n2() const { return n2_; } + int64 height() const { return n1_; } + int64 width() const { return n2_; } + int64 num_elements() const { return values_.size(); } + + // Low-level accessor for stuff like memcmp, handle with care. Returns pointer + // to the underlying storage of the array (similarly to std::vector::data()). + T* data() const { return const_cast(this)->values_.data(); } + + // Fills the array with the given value. + void Fill(const T& value) { + std::fill(values_.begin(), values_.end(), value); + } + + // Applies f to all cells in this array, in row-major order. + void Each(std::function f) { + for (int64 i0 = 0; i0 < n1(); ++i0) { + for (int64 i1 = 0; i1 < n2(); ++i1) { + f(i0, i1, &(*this)(i0, i1)); + } + } + } + + // Fills the array with a pattern of values of the form: + // + // (rowno << log2ceil(width) | colno) + start_value + // + // This makes it easy to see distinct row/column values in the array. + void FillUnique(T start_value = 0) { + for (int64 i0 = 0; i0 < n1(); ++i0) { + for (int64 i1 = 0; i1 < n2(); ++i1) { + (*this)(i0, i1) = + ((i0 << tensorflow::Log2Ceiling64(n2())) | i1) + start_value; + } + } + } + + // Fills the array with random normal variables of deviation value. + void FillRandom(const T& value, const double mean = 0.0, + const int seed = 12345) { + std::mt19937 g(seed); + std::normal_distribution distribution(mean, + static_cast(value)); + for (auto& v : values_) { + v = static_cast(distribution(g)); + } + } + + // Returns a readable string representation of the array. + string ToString() const { + std::vector pieces = {"["}; + for (int64 row = 0; row < height(); ++row) { + pieces.push_back("["); + for (int64 col = 0; col < width(); ++col) { + pieces.push_back(tensorflow::strings::StrCat((*this)(row, col))); + pieces.push_back(", "); + } + pieces.pop_back(); + pieces.push_back("]"); + pieces.push_back(",\n "); + } + pieces.pop_back(); + pieces.push_back("]"); + return tensorflow::str_util::Join(pieces, ""); + } + + private: + int64 n1_; + int64 n2_; + std::vector values_; +}; + +// Returns a linspace-populated Array2D in the range [from, to] (inclusive) +// with dimensions n1 x n2. +std::unique_ptr> MakeLinspaceArray2D(float from, float to, + int64 n1, int64 n2); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_ARRAY2D_H_ diff --git a/tensorflow/compiler/xla/array2d_test.cc b/tensorflow/compiler/xla/array2d_test.cc new file mode 100644 index 0000000000..ac107b1c0d --- /dev/null +++ b/tensorflow/compiler/xla/array2d_test.cc @@ -0,0 +1,132 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/array2d.h" + +#include + +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +TEST(Array2dTest, DefaultCtor) { + Array2D empty; + EXPECT_EQ(empty.n1(), 0); + EXPECT_EQ(empty.n2(), 0); + EXPECT_EQ(empty.num_elements(), 0); +} + +TEST(Array2dTest, UninitializedDimsCtor) { + Array2D uninit(2, 3); + EXPECT_EQ(uninit.n1(), 2); + EXPECT_EQ(uninit.n2(), 3); + EXPECT_EQ(uninit.num_elements(), 6); +} + +TEST(Array2dTest, FillCtor) { + Array2D fullof7(2, 3, 7); + + EXPECT_EQ(fullof7.n1(), 2); + EXPECT_EQ(fullof7.n2(), 3); + + for (int64 n1 = 0; n1 < fullof7.n1(); ++n1) { + for (int64 n2 = 0; n2 < fullof7.n2(); ++n2) { + EXPECT_EQ(fullof7(n1, n2), 7); + } + } +} + +TEST(Array2dTest, InitializerListCtor) { + Array2D arr = {{1, 2, 3}, {4, 5, 6}}; + + EXPECT_EQ(arr.n1(), 2); + EXPECT_EQ(arr.n2(), 3); + + EXPECT_EQ(arr(0, 0), 1); + EXPECT_EQ(arr(0, 1), 2); + EXPECT_EQ(arr(0, 2), 3); + EXPECT_EQ(arr(1, 0), 4); + EXPECT_EQ(arr(1, 1), 5); + EXPECT_EQ(arr(1, 2), 6); +} + +TEST(Array2dTest, Accessors) { + Array2D arr = {{1, 2, 3}, {4, 5, 6}}; + + EXPECT_EQ(arr.n1(), 2); + EXPECT_EQ(arr.n2(), 3); + EXPECT_EQ(arr.height(), 2); + EXPECT_EQ(arr.width(), 3); + EXPECT_EQ(arr.num_elements(), 6); +} + +TEST(Array2dTest, IndexingReadWrite) { + Array2D arr = {{1, 2, 3}, {4, 5, 6}}; + + EXPECT_EQ(arr(1, 1), 5); + EXPECT_EQ(arr(1, 2), 6); + arr(1, 1) = 51; + arr(1, 2) = 61; + EXPECT_EQ(arr(1, 1), 51); + EXPECT_EQ(arr(1, 2), 61); +} + +TEST(Array2dTest, Fill) { + Array2D fullof7(2, 3, 7); + for (int64 n1 = 0; n1 < fullof7.n1(); ++n1) { + for (int64 n2 = 0; n2 < fullof7.n2(); ++n2) { + EXPECT_EQ(fullof7(n1, n2), 7); + } + } + + fullof7.Fill(11); + for (int64 n1 = 0; n1 < fullof7.n1(); ++n1) { + for (int64 n2 = 0; n2 < fullof7.n2(); ++n2) { + EXPECT_EQ(fullof7(n1, n2), 11); + } + } +} + +TEST(Array2dTest, DataPointer) { + Array2D arr = {{1, 2, 3}, {4, 5, 6}}; + + EXPECT_EQ(arr.data()[0], 1); +} + +TEST(Array2dTest, Linspace) { + auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2); + + EXPECT_EQ(arr->n1(), 3); + EXPECT_EQ(arr->n2(), 2); + + EXPECT_FLOAT_EQ((*arr)(0, 0), 1.0); + EXPECT_FLOAT_EQ((*arr)(0, 1), 1.5); + EXPECT_FLOAT_EQ((*arr)(1, 0), 2.0); + EXPECT_FLOAT_EQ((*arr)(1, 1), 2.5); + EXPECT_FLOAT_EQ((*arr)(2, 0), 3.0); + EXPECT_FLOAT_EQ((*arr)(2, 1), 3.5); +} + +TEST(Array2dTest, Stringification) { + auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2); + const string expected = R"([[1, 1.5], + [2, 2.5], + [3, 3.5]])"; + EXPECT_EQ(expected, arr->ToString()); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/array3d.h b/tensorflow/compiler/xla/array3d.h new file mode 100644 index 0000000000..46bc1a6392 --- /dev/null +++ b/tensorflow/compiler/xla/array3d.h @@ -0,0 +1,127 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_ARRAY3D_H_ +#define TENSORFLOW_COMPILER_XLA_ARRAY3D_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// Simple 3D array structure. +// +// The data layout in major-to-minor order is: n1, n2, n3. +template +class Array3D { + public: + // Creates an array of dimensions n1 x n2 x n3, uninitialized values. + Array3D(const int64 n1, const int64 n2, const int64 n3) + : n1_(n1), n2_(n2), n3_(n3) { + values_.resize(n1 * n2 * n3); + } + + // Creates an array of dimensions n1 x n2 x n3, initialized to value. + Array3D(const int64 n1, const int64 n2, const int64 n3, const T value) + : Array3D(n1, n2, n3) { + Fill(value); + } + + // Creates an array from the given nested initializer list. The outer + // initializer list is the first dimension, and so on. + // + // For example {{{1, 2}, {3, 4}, {5, 6}, {7, 8}}, + // {{9, 10}, {11, 12}, {13, 14}, {15, 16}}, + // {{17, 18}, {19, 20}, {21, 22}, {23, 24}}} + // results in an array with n1=3, n2=4, n3=2. + Array3D(std::initializer_list>> + values) + : Array3D(values.size(), values.begin()->size(), + values.begin()->begin()->size()) { + int64 n1 = 0; + for (auto n1_it = values.begin(); n1_it != values.end(); ++n1_it, ++n1) { + int64 n2 = 0; + for (auto n2_it = n1_it->begin(); n2_it != n1_it->end(); ++n2_it, ++n2) { + int64 n3 = 0; + for (auto n3_it = n2_it->begin(); n3_it != n2_it->end(); + ++n3_it, ++n3) { + (*this)(n1, n2, n3) = *n3_it; + } + } + } + } + + T& operator()(const int64 n1, const int64 n2, const int64 n3) { + CHECK_LT(n1, n1_); + CHECK_LT(n2, n2_); + CHECK_LT(n3, n3_); + return values_[n1 * n2_ * n3_ + n2 * n3_ + n3]; + } + + const T& operator()(const int64 n1, const int64 n2, const int64 n3) const { + CHECK_LT(n1, n1_); + CHECK_LT(n2, n2_); + CHECK_LT(n3, n3_); + return values_[n1 * n2_ * n3_ + n2 * n3_ + n3]; + } + + // Access to the array's dimensions. + int64 n1() const { return n1_; } + int64 n2() const { return n2_; } + int64 n3() const { return n3_; } + int64 num_elements() const { return values_.size(); } + + // Fills the array with the given value. + void Fill(const T& value) { + std::fill(values_.begin(), values_.end(), value); + } + + // Fills the array with sequentially increasing values. + void FillIota(const T& value) { + std::iota(values_.begin(), values_.end(), value); + } + + // Fills the array with random normal values with a mean of 0 and standard + // deviation of value. 
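  // (The mean defaults to 0 but can be overridden via the 'mean' argument.)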
+ void FillRandom(const T& value, const double mean = 0.0, + const int seed = 12345) { + std::mt19937 g(seed); + std::normal_distribution distribution(mean, + static_cast(value)); + for (auto& v : values_) { + v = static_cast(distribution(g)); + } + } + + private: + int64 n1_; + int64 n2_; + int64 n3_; + std::vector values_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_ARRAY3D_H_ diff --git a/tensorflow/compiler/xla/array3d_test.cc b/tensorflow/compiler/xla/array3d_test.cc new file mode 100644 index 0000000000..fa4435dfc4 --- /dev/null +++ b/tensorflow/compiler/xla/array3d_test.cc @@ -0,0 +1,93 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/array3d.h" + +#include + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +TEST(Array3dTest, UninitializedDimsCtor) { + Array3D uninit(2, 3, 4); + EXPECT_EQ(uninit.n1(), 2); + EXPECT_EQ(uninit.n2(), 3); + EXPECT_EQ(uninit.n3(), 4); + EXPECT_EQ(uninit.num_elements(), 24); +} + +TEST(Array3dTest, FillCtor) { + Array3D fullof7(2, 3, 4, 7); + + EXPECT_EQ(fullof7.n1(), 2); + EXPECT_EQ(fullof7.n2(), 3); + EXPECT_EQ(fullof7.n3(), 4); + + for (int64 n1 = 0; n1 < fullof7.n1(); ++n1) { + for (int64 n2 = 0; n2 < fullof7.n2(); ++n2) { + for (int64 n3 = 0; n3 < fullof7.n3(); ++n3) { + EXPECT_EQ(fullof7(n1, n2, n3), 7); + } + } + } +} + +TEST(Array3dTest, InitializerListCtor) { + Array3D arr = {{{1, 2}, {3, 4}, {5, 6}, {7, 8}}, + {{9, 10}, {11, 12}, {13, 14}, {15, 16}}, + {{17, 18}, {19, 20}, {21, 22}, {23, 24}}}; + + EXPECT_EQ(arr.n1(), 3); + EXPECT_EQ(arr.n2(), 4); + EXPECT_EQ(arr.n3(), 2); + EXPECT_EQ(arr.num_elements(), 24); + + EXPECT_EQ(arr(0, 0, 0), 1); + EXPECT_EQ(arr(0, 0, 1), 2); + EXPECT_EQ(arr(0, 1, 0), 3); + EXPECT_EQ(arr(0, 3, 1), 8); + EXPECT_EQ(arr(1, 0, 0), 9); + EXPECT_EQ(arr(1, 1, 1), 12); + EXPECT_EQ(arr(2, 0, 0), 17); + EXPECT_EQ(arr(2, 1, 1), 20); + EXPECT_EQ(arr(2, 2, 0), 21); + EXPECT_EQ(arr(2, 3, 1), 24); +} + +TEST(Array3dTest, Fill) { + Array3D fullof7(2, 3, 4, 7); + for (int64 n1 = 0; n1 < fullof7.n1(); ++n1) { + for (int64 n2 = 0; n2 < fullof7.n2(); ++n2) { + for (int64 n3 = 0; n3 < fullof7.n3(); ++n3) { + EXPECT_EQ(fullof7(n1, n2, n3), 7); + } + } + } + + fullof7.Fill(11); + for (int64 n1 = 0; n1 < fullof7.n1(); ++n1) { + for (int64 n2 = 0; n2 < fullof7.n2(); ++n2) { + for (int64 n3 = 0; n3 < fullof7.n3(); ++n3) { + EXPECT_EQ(fullof7(n1, n2, n3), 11); + } + } + } +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/array4d.h b/tensorflow/compiler/xla/array4d.h new file mode 100644 index 0000000000..db51a57cf2 --- /dev/null +++ b/tensorflow/compiler/xla/array4d.h @@ -0,0 +1,272 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_ARRAY4D_H_ +#define TENSORFLOW_COMPILER_XLA_ARRAY4D_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// Simple 4D array structure, similar in form to Array2D, for use primarily in +// testing and describing to XLA APIs values in the 4D array structures used +// in convolutions. +// +// The data layout is, in order from major to minor: +// +// First dimension: plane, batch, n1 +// Second dimension: depth, feature, z, n2 +// Third dimension: height, y, n3 +// Fourth dimension: width, x, n4 +// +// These dimensions are referred to by various names, so that is why +// more than one name is given above. See operator() for the exact +// calculation of 1d indices from 4d indices. +template +class Array4D { + public: + // Creates a 4D array, unitialized values. + Array4D(int64 planes, int64 depth, int64 height, int64 width) + : planes_(planes), depth_(depth), height_(height), width_(width) { + values_.resize(planes * depth * height * width); + } + + // Creates a 4D array, initalized to value. + Array4D(int64 planes, int64 depth, int64 height, int64 width, T value) + : Array4D(planes, depth, height, width) { + Fill(value); + } + + // Creates a 4D array, filled with values. + // + // We need to set a default type for Container so that code like + // Array4D(1, 1, 1, 1, {1}) will work. The template cannot infer the + // initializer_list type in that case without this default. + template > + Array4D(int64 planes, int64 depth, int64 height, int64 width, + const Container& values) + : Array4D(planes, depth, height, width) { + SetValues(values); + } + + // Construct an Array4D with the given nested initializer list. 
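  // The outer list is the plane dimension, then depth, then height, with the
  // innermost list giving width; e.g. (illustrative) {{{{1, 2}}}} yields
  // planes=1, depth=1, height=1, width=2.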
+ Array4D(std::initializer_list>>> + values) + : Array4D(values.size(), values.begin()->size(), + values.begin()->begin()->size(), + values.begin()->begin()->begin()->size()) { + int64 plane = 0; + for (const auto values_in_plane : values) { + DCHECK_EQ(values_in_plane.size(), depth_); + int64 depth = 0; + for (const auto values_in_depth : values_in_plane) { + DCHECK_EQ(values_in_depth.size(), height_); + int64 height = 0; + for (const auto values_in_height : values_in_depth) { + DCHECK_EQ(values_in_height.size(), width_); + int64 width = 0; + for (const auto element_value : values_in_height) { + (*this)(plane, depth, height, width) = element_value; + ++width; + } + ++height; + } + ++depth; + } + ++plane; + } + } + + T& operator()(int64 plane, int64 depth, int64 height, int64 width) { + CHECK_LT(plane, planes_); + CHECK_LT(depth, depth_); + CHECK_LT(height, height_); + CHECK_LT(width, width_); + return values_[plane * (depth_ * height_ * width_) + + depth * (height_ * width_) + height * (width_) + width]; + } + const T& operator()(int64 plane, int64 depth, int64 height, + int64 width) const { + return const_cast(this)->operator()(plane, depth, height, width); + } + + int64 width() const { return width_; } + int64 height() const { return height_; } + int64 depth() const { return depth_; } + int64 planes() const { return planes_; } + + // Numerically-named aliases for the various dimensions. This matches the + // dimension names used in array3d. + int64 n4() const { return width_; } + int64 n3() const { return height_; } + int64 n2() const { return depth_; } + int64 n1() const { return planes_; } + int64 num_elements() const { return values_.size(); } + + // Sets all the values in the array to values. + template > + void SetValues(const Container& container) { + CHECK_EQ(std::distance(std::begin(container), std::end(container)), + num_elements()); + values_.assign(std::begin(container), std::end(container)); + } + + // Fills the array with the given value. + void Fill(const T& value) { + std::fill(values_.begin(), values_.end(), value); + } + + // Fills the array with iota. + void FillIota(const T& value) { + std::iota(values_.begin(), values_.end(), value); + } + + // Fills the array with random variable with a deviation of value and a mean + // of mean. + void FillRandom(const T& value, const double mean = 0.0, + const int seed = 12345) { + std::mt19937 g(seed); + std::normal_distribution distribution(mean, + static_cast(value)); + for (auto& v : values_) { + v = static_cast(distribution(g)); + } + } + + // Fills values with the sequence i*multiplier for i=0,1,... + void FillWithMultiples(float multiplier) { + for (int64 i = 0; i < num_elements(); ++i) { + values_[i] = i * multiplier; + } + } + + // Invokes a callback with the (indices, value_ptr) for each cell in the 4D + // array. + void Each(std::function, T*)> f) { + for (int64 plane = 0; plane < planes(); ++plane) { + for (int64 depth = 0; depth < this->depth(); ++depth) { + for (int64 height = 0; height < this->height(); ++height) { + for (int64 width = 0; width < this->width(); ++width) { + auto& value = (*this)(plane, depth, height, width); + f({plane, depth, height, width}, &value); + } + } + } + } + } + + // Fills all of the {p,z} with the array provided, which specifies {y,x}. 
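  // In other words, every (plane, depth) slice is set to the same
  // height x width matrix given by 'value'.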
+ void FillWithYX(const Array2D& value) { + CHECK_EQ(value.height(), height()); + CHECK_EQ(value.width(), width()); + for (int64 plane = 0; plane < planes(); ++plane) { + for (int64 depth = 0; depth < this->depth(); ++depth) { + for (int64 height = 0; height < this->height(); ++height) { + for (int64 width = 0; width < this->width(); ++width) { + (*this)(plane, depth, height, width) = value(height, width); + } + } + } + } + } + + // Fills all of the {x,y} with the array provided, which specifies {p,z}. + void FillWithPZ(const Array2D& value) { + CHECK_EQ(value.height(), planes()); + CHECK_EQ(value.width(), depth()); + for (int64 height = 0; height < this->height(); ++height) { + for (int64 width = 0; width < this->width(); ++width) { + for (int64 plane = 0; plane < planes(); ++plane) { + for (int64 depth = 0; depth < this->depth(); ++depth) { + (*this)(plane, depth, height, width) = value(plane, depth); + } + } + } + } + } + + // Fills each of the minor-dim matrices with a number designating which minor + // dim matrix is enclosed by the shape. + void FillWithMinorDimNum() { + LOG(INFO) << "width: " << this->width(); + LOG(INFO) << "height: " << this->height(); + LOG(INFO) << "depth: " << this->depth(); + LOG(INFO) << "planes: " << this->planes(); + for (int64 height = 0; height < this->height(); ++height) { + for (int64 width = 0; width < this->width(); ++width) { + for (int64 plane = 0; plane < planes(); ++plane) { + for (int64 depth = 0; depth < this->depth(); ++depth) { + float this_val = plane * this->depth() + depth; + (*this)(plane, depth, height, width) = this_val; + } + } + } + } + } + + // Returns a string representation of the 4D array suitable for debugging. + string ToString() const { + std::vector pieces = { + tensorflow::strings::Printf("p=%lld,z=%lld,y=%lld,x=%lld {\n", planes(), + depth(), height(), width())}; + for (int64 plane = 0; plane < planes_; ++plane) { + pieces.push_back(" {\n"); + for (int64 depth = 0; depth < depth_; ++depth) { + pieces.push_back(" {\n"); + for (int64 height = 0; height < height_; ++height) { + pieces.push_back(" {"); + for (int64 width = 0; width < width_; ++width) { + pieces.push_back(tensorflow::strings::StrCat( + (*this)(plane, depth, height, width), ", ")); + } + pieces.push_back("},\n"); + } + pieces.push_back(" },\n"); + } + pieces.push_back(" },\n"); + } + pieces.push_back("}"); + return tensorflow::str_util::Join(pieces, ""); + } + + private: + int64 planes_; + int64 depth_; + int64 height_; + int64 width_; + std::vector values_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_ARRAY4D_H_ diff --git a/tensorflow/compiler/xla/array4d_test.cc b/tensorflow/compiler/xla/array4d_test.cc new file mode 100644 index 0000000000..72ada467e5 --- /dev/null +++ b/tensorflow/compiler/xla/array4d_test.cc @@ -0,0 +1,180 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/array4d.h" + +#include +#include + +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +// Given an Array4D and a 4-tuple index, computes the linear index into the +// array idx represents. +template +int64 Array4DLinearIndex(const Array4D& arr, + tensorflow::gtl::ArraySlice idx) { + EXPECT_EQ(4, idx.size()); + return (idx[3] + idx[2] * arr.n4() + idx[1] * arr.n3() * arr.n4() + + idx[0] * arr.n2() * arr.n3() * arr.n4()); +} + +TEST(Array4dTest, UninitializedDimsCtor) { + Array4D empty(2, 3, 4, 5); + EXPECT_EQ(empty.n1(), 2); + EXPECT_EQ(empty.n2(), 3); + EXPECT_EQ(empty.n3(), 4); + EXPECT_EQ(empty.n4(), 5); + EXPECT_EQ(empty.num_elements(), 120); +} + +TEST(Array4dTest, FillCtor) { + Array4D fullof7(2, 3, 4, 5, 7); + + EXPECT_EQ(fullof7.n1(), 2); + EXPECT_EQ(fullof7.n2(), 3); + EXPECT_EQ(fullof7.n3(), 4); + EXPECT_EQ(fullof7.n4(), 5); + + fullof7.Each([](tensorflow::gtl::ArraySlice idx, int* cell) { + EXPECT_EQ(*cell, 7); + }); +} + +TEST(Array4dTest, ContainerCtor) { + // Fill an Array4D with a linear vector of [0..119] according to the default + // row-major ordering. + std::vector filler(120); + std::iota(filler.begin(), filler.end(), 0); + + Array4D arr(2, 3, 4, 5, filler); + + EXPECT_EQ(arr.n1(), 2); + EXPECT_EQ(arr.n2(), 3); + EXPECT_EQ(arr.n3(), 4); + EXPECT_EQ(arr.n4(), 5); + + arr.Each([&arr](tensorflow::gtl::ArraySlice idx, int* cell) { + EXPECT_EQ(*cell, Array4DLinearIndex(arr, idx)); + }); +} + +TEST(Array3dTest, InitializerListCtor) { + Array4D arr = {{{{1}, {2}}, {{3}, {4}}, {{5}, {6}}, {{7}, {8}}}, + {{{9}, {10}}, {{11}, {12}}, {{13}, {14}}, {{15}, {16}}}, + {{{17}, {18}}, {{19}, {20}}, {{21}, {22}}, {{23}, {24}}}}; + + EXPECT_EQ(arr.n1(), 3); + EXPECT_EQ(arr.n2(), 4); + EXPECT_EQ(arr.n3(), 2); + EXPECT_EQ(arr.n4(), 1); + EXPECT_EQ(arr.num_elements(), 24); + + EXPECT_EQ(arr(0, 0, 0, 0), 1); + EXPECT_EQ(arr(0, 0, 1, 0), 2); + EXPECT_EQ(arr(0, 1, 0, 0), 3); + EXPECT_EQ(arr(0, 3, 1, 0), 8); + EXPECT_EQ(arr(1, 0, 0, 0), 9); + EXPECT_EQ(arr(1, 1, 1, 0), 12); + EXPECT_EQ(arr(2, 0, 0, 0), 17); + EXPECT_EQ(arr(2, 1, 1, 0), 20); + EXPECT_EQ(arr(2, 2, 0, 0), 21); + EXPECT_EQ(arr(2, 3, 1, 0), 24); +} + +TEST(Array4dTest, Fill) { + Array4D fullof7(2, 3, 4, 5, 7); + fullof7.Each([](tensorflow::gtl::ArraySlice idx, int* cell) { + EXPECT_EQ(*cell, 7); + }); + + fullof7.Fill(11); + fullof7.Each([](tensorflow::gtl::ArraySlice idx, int* cell) { + EXPECT_EQ(*cell, 11); + }); +} + +TEST(Array4dTest, FillWithMultiples) { + Array4D arr(2, 3, 4, 5); + arr.FillWithMultiples(2.0f); + + arr.Each([&arr](tensorflow::gtl::ArraySlice idx, float* cell) { + EXPECT_EQ(*cell, 2.0f * Array4DLinearIndex(arr, idx)); + }); +} + +TEST(Array4dTest, FillRasterDimensionDepthOne) { + Array4D array(1, 1, 128, 128); + Array2D raster(128, 128); + for (int row = 0; row < 128; ++row) { + for (int col = 0; col < 128; ++col) { + raster(row, col) = row * 1000.0 + col; + } + } + + array.FillWithYX(raster); + + VLOG(1) << array.ToString(); + + EXPECT_FLOAT_EQ(raster(0, 0), array(0, 0, 0, 0)); + EXPECT_FLOAT_EQ(raster(0, 1), array(0, 0, 0, 1)); + EXPECT_FLOAT_EQ(raster(1, 0), array(0, 0, 1, 0)); + EXPECT_FLOAT_EQ(raster(1, 1), array(0, 0, 1, 1)); + EXPECT_FLOAT_EQ(raster(2, 0), array(0, 0, 2, 0)); + EXPECT_FLOAT_EQ(raster(127, 127), array(0, 0, 127, 127)); + + EXPECT_FLOAT_EQ(0, array(0, 0, 0, 0)); + EXPECT_FLOAT_EQ(1, array(0, 0, 0, 
1)); + EXPECT_FLOAT_EQ(2, array(0, 0, 0, 2)); + + EXPECT_FLOAT_EQ(1001, array(0, 0, 1, 1)); + EXPECT_FLOAT_EQ(2001, array(0, 0, 2, 1)); + EXPECT_FLOAT_EQ(127000, array(0, 0, 127, 0)); + EXPECT_FLOAT_EQ(127127, array(0, 0, 127, 127)); +} + +TEST(Array4dTest, FillWithPzTestDepthOne) { + Array2D matrix(3, 2); + std::initializer_list> values = { + {-3.f, -0.1f}, {0.f, -0.1f}, {3.f, 0.2f}, + }; + int rowno = 0; + for (auto row : values) { + int colno = 0; + for (float f : row) { + matrix(rowno, colno) = f; + colno++; + } + rowno++; + } + + Array4D actual(3, 2, 1, 1); + actual.FillWithPZ(matrix); + + EXPECT_FLOAT_EQ(-3, actual(0, 0, 0, 0)); + EXPECT_FLOAT_EQ(-0.1, actual(0, 1, 0, 0)); + + EXPECT_FLOAT_EQ(0, actual(1, 0, 0, 0)); + EXPECT_FLOAT_EQ(-0.1, actual(1, 1, 0, 0)); + + EXPECT_FLOAT_EQ(3, actual(2, 0, 0, 0)); + EXPECT_FLOAT_EQ(0.2, actual(2, 1, 0, 0)); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD new file mode 100644 index 0000000000..3e9dfe2a92 --- /dev/null +++ b/tensorflow/compiler/xla/client/BUILD @@ -0,0 +1,175 @@ +# Description: +# XLA client libraries. + +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = [":friends"]) + +package_group( + name = "friends", + includes = [ + "//tensorflow/compiler/xla:friends", + ], +) + +# Filegroup used to collect source files for dependency checking. +filegroup( + name = "c_srcs", + data = glob([ + "**/*.cc", + "**/*.h", + ]), +) + +cc_library( + name = "global_data", + srcs = ["global_data.cc"], + hdrs = ["global_data.h"], + deps = [ + "//tensorflow/compiler/xla:service_interface", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla:xla_proto", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "padding", + srcs = ["padding.cc"], + hdrs = ["padding.h"], + deps = [ + "//tensorflow/compiler/xla:types", + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "padding_test", + srcs = ["padding_test.cc"], + deps = [ + ":padding", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "client", + srcs = ["client.cc"], + hdrs = ["client.h"], + deps = [ + ":computation", + ":global_data", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:service_interface", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla:xla_proto", + "//tensorflow/compiler/xla/service:session_proto", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "local_client", + srcs = ["local_client.cc"], + hdrs = ["local_client.h"], + deps = [ + ":client", + ":computation", + "//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:backend", + "//tensorflow/compiler/xla/service:compiler", + "//tensorflow/compiler/xla/service:device_memory_allocator", + "//tensorflow/compiler/xla/service:executable", + "//tensorflow/compiler/xla/service:local_service", + "//tensorflow/compiler/xla/service:shaped_buffer", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + "@llvm//:support", + ], +) + +# This target is used to 
instantiate the XLA service in-process and create +# a client for it. +cc_library( + name = "client_library", + srcs = ["client_library.cc"], + hdrs = ["client_library.h"], + deps = [ + ":local_client", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:backend", + "//tensorflow/compiler/xla/service:device_memory_allocator", + "//tensorflow/compiler/xla/service:local_service", + "//tensorflow/compiler/xla/service:platform_util", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + ], +) + +cc_library( + name = "computation", + srcs = ["computation.cc"], + hdrs = ["computation.h"], + deps = [ + "//tensorflow/compiler/xla:service_interface", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla:xla_proto", + "//tensorflow/compiler/xla/service:session_proto", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "computation_builder", + srcs = ["computation_builder.cc"], + hdrs = ["computation_builder.h"], + deps = [ + ":client", + ":computation", + ":global_data", + ":padding", + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:array3d", + "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla:xla_proto", + "//tensorflow/core:lib", + ], +) + +# ----------------------------------------------------------------------------- + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc new file mode 100644 index 0000000000..f70ab294c0 --- /dev/null +++ b/tensorflow/compiler/xla/client/client.cc @@ -0,0 +1,479 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/client/client.h" + +#include +#include + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +Client::Client(ServiceInterface* stub) : stub_(stub) {} + +Client::~Client() = default; + +StatusOr> Client::Transfer( + const GlobalData& data, const Shape* shape_with_layout) { + TransferToClientRequest request; + *request.mutable_data() = data.handle(); + if (shape_with_layout != nullptr) { + *request.mutable_shape_with_layout() = *shape_with_layout; + } + TransferToClientResponse response; + + VLOG(1) << "making transfer request"; + VLOG(3) << "TransferToClientRequest: {" << request.DebugString() << "}"; + Status s = stub_->TransferToClient(&request, &response); + VLOG(1) << "done with request"; + + if (!s.ok()) { + return s; + } + VLOG(3) << "TransferToClientResponse: {" << response.DebugString() << "}"; + + if (!response.has_literal()) { + return FailedPrecondition( + "server provided response without a literal in " + "TransferToClient request"); + } + + return WrapUnique(response.release_literal()); +} + +Status Client::TransferInProcess(const GlobalData& data, void* destination) { + TransferToClientInProcessRequest request; + *request.mutable_data() = data.handle(); + request.set_buffer(reinterpret_cast(destination)); + TransferToClientInProcessResponse response; + + VLOG(1) << "making transfer in-process request"; + VLOG(3) << "TransferToClientInProcessRequest: {" << request.DebugString() + << "}"; + Status s = stub_->TransferToClientInProcess(&request, &response); + VLOG(1) << "done with request"; + + if (!s.ok()) { + return s; + } + VLOG(3) << "TransferToClientInProcessResponse: {" << response.DebugString() + << "}"; + return Status::OK(); +} + +StatusOr> Client::TransferToServer( + const Literal& literal, const DeviceHandle* device_handle) { + TransferToServerRequest request; + *request.mutable_literal() = literal; + if (device_handle) { + *request.mutable_device_handle() = *device_handle; + } + TransferToServerResponse response; + + VLOG(1) << "making transfer to server request"; + VLOG(3) << "TransferToServerRequest: {" << request.DebugString() << "}"; + Status s = stub_->TransferToServer(&request, &response); + VLOG(1) << "done with request"; + + if (!s.ok()) { + return s; + } + VLOG(3) << "TransferToServerResponse: {" << response.DebugString() << "}"; + + if (!response.has_data()) { + return FailedPrecondition( + "server provided response without a data handle in " + "TransferToServer request"); + } + + return MakeUnique(stub_, response.data()); +} + +Status Client::TransferToInfeed(const Literal& literal, int64 replica_id, + const DeviceHandle* device_handle) { + TransferToInfeedRequest request; + *request.mutable_literal() = literal; + if (device_handle) { + *request.mutable_device_handle() = *device_handle; + } + request.set_replica_id(replica_id); + TransferToInfeedResponse response; + + VLOG(1) << "making transfer to infeed request"; + VLOG(3) << "TransferToInfeedRequest: {" << request.DebugString() << "}"; + Status s = stub_->TransferToInfeed(&request, &response); + 
VLOG(1) << "done with request"; + + if (!s.ok()) { + return s; + } + VLOG(3) << "TransferToInfeedResponse: {" << response.DebugString() << "}"; + return Status::OK(); +} + +Status Client::ResetDevice() { + ResetDeviceRequest request; + ResetDeviceResponse response; + + VLOG(1) << "making reset device request"; + VLOG(3) << "ResetDeviceRequest: {" << request.DebugString() << "}"; + Status s = stub_->ResetDevice(&request, &response); + VLOG(1) << "done with request"; + + if (!s.ok()) { + return s; + } + VLOG(3) << "ResetDeviceResponse: {" << response.DebugString() << "}"; + return Status::OK(); +} + +StatusOr> Client::ExecuteAndTransfer( + const Computation& computation, + tensorflow::gtl::ArraySlice arguments, + const Shape* shape_with_output_layout, ExecutionProfile* execution_profile, + uint64 seed) { + TF_ASSIGN_OR_RETURN(std::unique_ptr data, + Execute(computation, arguments, shape_with_output_layout, + execution_profile, seed)); + return Transfer(*data, shape_with_output_layout); +} + +StatusOr> Client::TransferToServerInProcess( + const Shape& shape, const void* buffer) { + TransferToServerInProcessRequest request; + request.set_buffer(reinterpret_cast(buffer)); + *request.mutable_shape() = shape; + TransferToServerInProcessResponse response; + + VLOG(1) << "making transfer to server in-process request"; + VLOG(3) << "TransferToServerInProcessRequest: {" << request.DebugString() + << "}"; + Status s = stub_->TransferToServerInProcess(&request, &response); + VLOG(1) << "done with request"; + + if (!s.ok()) { + return s; + } + VLOG(3) << "TransferToServerInProcessResponse: {" << response.DebugString() + << "}"; + + if (!response.has_data()) { + return FailedPrecondition( + "server provided response without a data handle in " + "TransferToServerInProcess request"); + } + + return MakeUnique(stub_, response.data()); +} + +StatusOr Client::LoadSnapshot(const SessionModule& module) { + LoadComputationSnapshotRequest request; + *request.mutable_module() = module; + LoadComputationSnapshotResponse response; + + Status s = stub_->LoadComputationSnapshot(&request, &response); + if (!s.ok()) { + return s; + } + + VLOG(1) << "load snapshot response: " << response.ShortDebugString(); + return Computation(stub_, response.computation()); +} + +StatusOr> Client::Execute( + const Computation& computation, + tensorflow::gtl::ArraySlice arguments, + const Shape* shape_with_output_layout, ExecutionProfile* execution_profile, + uint64 seed) { + ExecuteRequest request; + *request.mutable_computation() = computation.handle(); + request.set_seed(seed); + for (GlobalData* argument : arguments) { + *request.add_arguments() = argument->handle(); + } + if (shape_with_output_layout != nullptr) { + *request.mutable_shape_with_output_layout() = *shape_with_output_layout; + } + + ExecuteResponse response; + VLOG(1) << "making execute request: " << request.ShortDebugString(); + Status s = stub_->Execute(&request, &response); + VLOG(1) << "done with request"; + + if (!s.ok()) { + return s; + } + + if (execution_profile != nullptr) { + *execution_profile = response.profile(); + if (VLOG_IS_ON(1)) { + TF_ASSIGN_OR_RETURN( + auto execution_stats, + ExecutionStatsAsString(computation, response.profile())); + VLOG(1) << execution_stats; + } + } + + return MakeUnique(stub_, response.output()); +} + +StatusOr>> Client::ExecuteParallel( + tensorflow::gtl::ArraySlice computations) { + ExecuteParallelRequest request; + + for (const ComputationInstance& computation : computations) { + ExecuteRequest single_request; + 
*single_request.mutable_computation() = computation.computation.handle(); + for (GlobalData* argument : computation.arguments) { + *single_request.add_arguments() = argument->handle(); + } + if (computation.device_handle != nullptr) { + *single_request.mutable_device_handle() = *computation.device_handle; + } + if (computation.shape_with_output_layout != nullptr) { + *single_request.mutable_shape_with_output_layout() = + *computation.shape_with_output_layout; + } + single_request.set_seed(computation.seed); + *request.add_requests() = single_request; + } + + ExecuteParallelResponse response; + VLOG(1) << "making execute-parallel request: " << request.ShortDebugString(); + tensorflow::Status s = stub_->ExecuteParallel(&request, &response); + VLOG(1) << "done with request"; + + if (!s.ok()) { + return s; + } + + std::vector> outputs; + for (int64 i = 0; i < computations.size(); ++i) { + outputs.push_back( + MakeUnique(stub_, response.responses(i).output())); + if (computations[i].execution_profile != nullptr) { + *computations[i].execution_profile = response.responses(i).profile(); + } + } + + return std::move(outputs); +} + +StatusOr> Client::GetDeviceHandles( + int64 device_count) { + if (device_count < 1) { + return InvalidArgument("device_count must be greater than 0"); + } + GetDeviceHandlesRequest request; + request.set_device_count(device_count); + + GetDeviceHandlesResponse response; + VLOG(1) << "making get device request: " << request.ShortDebugString(); + tensorflow::Status s = stub_->GetDeviceHandles(&request, &response); + VLOG(1) << "done with request"; + + if (!s.ok()) { + return s; + } + + std::vector device_handles; + for (const DeviceHandle& device_handle : response.device_handles()) { + device_handles.push_back(device_handle); + } + + return device_handles; +} + +StatusOr Client::ExecuteAsync( + const Computation& computation, + tensorflow::gtl::ArraySlice arguments, + const Shape* shape_with_output_layout, uint64 seed) { + ExecuteAsyncRequest request; + *request.mutable_computation() = computation.handle(); + request.set_seed(seed); + for (GlobalData* argument : arguments) { + *request.add_arguments() = argument->handle(); + } + if (shape_with_output_layout != nullptr) { + *request.mutable_shape_with_output_layout() = *shape_with_output_layout; + } + + ExecuteAsyncResponse response; + VLOG(1) << "making execute async request: " << request.ShortDebugString(); + Status s = stub_->ExecuteAsync(&request, &response); + VLOG(1) << "done with request"; + + if (!s.ok()) { + return s; + } + + return response.execution(); +} + +StatusOr> Client::WaitForExecution( + const Computation& computation, const ExecutionHandle& execution, + ExecutionProfile* execution_profile) { + WaitForExecutionRequest request; + *request.mutable_execution() = execution; + + WaitForExecutionResponse response; + VLOG(1) << "making wait-for-execute request: " << request.ShortDebugString(); + Status s = stub_->WaitForExecution(&request, &response); + VLOG(1) << "done with request"; + + if (!s.ok()) { + return s; + } + + if (execution_profile != nullptr) { + *execution_profile = response.profile(); + if (VLOG_IS_ON(1)) { + TF_ASSIGN_OR_RETURN( + auto execution_stats, + ExecutionStatsAsString(computation, response.profile())); + VLOG(1) << execution_stats; + } + } + + return MakeUnique(stub_, response.output()); +} + +Status Client::Unregister(const GlobalData& data) { + UnregisterRequest request; + *request.mutable_data() = data.handle(); + UnregisterResponse response; + + VLOG(1) << "making unregister 
request"; + Status s = stub_->Unregister(&request, &response); + VLOG(1) << "done with request"; + + return s; +} + +StatusOr>> Client::DeconstructTuple( + const GlobalData& data) { + DeconstructTupleRequest request; + *request.mutable_tuple_handle() = data.handle(); + DeconstructTupleResponse response; + + VLOG(1) << "making DestructTuple request"; + Status s = stub_->DeconstructTuple(&request, &response); + VLOG(1) << "done with request"; + + if (!s.ok()) { + return s; + } + + std::vector> handles; + for (auto& handle : response.element_handles()) { + handles.push_back(MakeUnique(stub_, handle)); + } + return std::move(handles); +} + +StatusOr Client::GetComputationStats( + const Computation& computation) const { + ComputationStatsRequest request; + *request.mutable_computation() = computation.handle(); + ComputationStatsResponse response; + + VLOG(1) << "making computation stats request"; + Status s = stub_->GetComputationStats(&request, &response); + VLOG(1) << "done with request"; + + if (!s.ok()) { + return s; + } + CHECK(response.has_stats()); + return response.stats(); +} + +StatusOr> Client::GetComputationShape( + const Computation& computation) { + GetComputationShapeRequest request; + *request.mutable_computation() = computation.handle(); + GetComputationShapeResponse response; + + VLOG(1) << "making get-computation-shape request"; + Status s = stub_->GetComputationShape(&request, &response); + VLOG(1) << "done with request"; + + if (!s.ok()) { + return s; + } + + return WrapUnique(response.release_program_shape()); +} + +StatusOr Client::GetShape(const GlobalData& data) { + GetShapeRequest request; + *request.mutable_data() = data.handle(); + GetShapeResponse response; + + VLOG(1) << "making get shape request"; + Status s = stub_->GetShape(&request, &response); + VLOG(1) << "done with request"; + + if (!s.ok()) { + return s; + } + + return response.shape(); +} + +StatusOr Client::ExecutionStatsAsString( + const Computation& computation, const ExecutionProfile& profile) { + TF_ASSIGN_OR_RETURN(auto computation_stats, GetComputationStats(computation)); + int64 total_flops = + computation_stats.flop_count() + computation_stats.transcendental_count(); + if (profile.compute_time_ns() > 0) { + int64 nanoseconds = profile.compute_time_ns(); + int64 cycle_count = profile.compute_cycle_count(); + double gflops = total_flops / nanoseconds; + return tensorflow::strings::StrCat( + "[Execution Statistics] flop count: ", computation_stats.flop_count(), + ", transcendental count: ", computation_stats.transcendental_count(), + ", compute execution time: ", nanoseconds, " nsec", + ", compute cycles: ", cycle_count, ", performance: ", gflops, + "gflop/s"); + } + return string("[Execution Statistics] not available."); +} + +StatusOr Client::CreateChannelHandle() { + CreateChannelHandleRequest request; + CreateChannelHandleResponse response; + + VLOG(1) << "making create channel handle request"; + Status s = stub_->CreateChannelHandle(&request, &response); + VLOG(1) << "done with request"; + + if (!s.ok()) { + return s; + } + + return response.channel(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h new file mode 100644 index 0000000000..b5beeb36b1 --- /dev/null +++ b/tensorflow/compiler/xla/client/client.h @@ -0,0 +1,202 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_CLIENT_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_CLIENT_H_ + +#include +#include + +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/service/session.pb.h" +#include "tensorflow/compiler/xla/service_interface.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla.pb.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/macros.h" + +namespace xla { + +// XLA service's client object -- wraps the service with convenience and +// lifetime-oriented methods. +class Client { + public: + explicit Client(ServiceInterface* stub); + virtual ~Client(); + + // Executes the computation with the given arguments and returns the global + // data that was produced from the execution. + // * If shape_with_output_layout is not nullptr this points to a shape with a + // layout to use as a hint when storing the output of the computation. + // Subsequent transfers of this output array to the client may be faster + // when using this layout. + // * If execution_profile is not nullptr then the pointed-to ExecutionProfile + // will be filled with profile data from the execution. + // * If seed is set then that seed is used for the graph execution. + StatusOr> Execute( + const Computation& computation, + tensorflow::gtl::ArraySlice arguments, + const Shape* shape_with_output_layout = nullptr, + ExecutionProfile* execution_profile = nullptr, uint64 seed = 0); + + // A struct to represent a computation instance to be executed. + // * If device_handle is not nullptr, the computation is executed on a device + // associated with the handle. Otherwise, a device is chosen by the service. + // * If shapes_with_output_layout is not nullptr, the given shape and its + // layout is used as a hint when storing the output of the computation. + // * If execution_profile is not nullptr, the pointed-to ExecutionProfile will + // be filled with profile data from the execution of the computation. + // * seed is for the random number generator used in the computation. + struct ComputationInstance { + const Computation& computation; + std::vector arguments; + const DeviceHandle* device_handle; + const Shape* shape_with_output_layout; + ExecutionProfile* execution_profile; + uint64 seed; + }; + + // Executes a list ComputationInstances and returns global data produced from + // each computation. + StatusOr>> ExecuteParallel( + tensorflow::gtl::ArraySlice computations); + + // Requests device_count device handles available on the target. The returned + // device handles are used to specify the devices to execute the computations + // (see ExecuteParallel) or to transfer data (see TransferToServer or + // TransferToInfeed). 
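+ //
+ // A sketch of typical use together with TransferToServer (identifiers
+ // other than the Client methods themselves are illustrative):
+ //
+ //   TF_ASSIGN_OR_RETURN(std::vector<DeviceHandle> devices,
+ //                       client->GetDeviceHandles(/*device_count=*/2));
+ //   TF_ASSIGN_OR_RETURN(std::unique_ptr<GlobalData> arg,
+ //                       client->TransferToServer(literal, &devices[0]));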
+ StatusOr> GetDeviceHandles(int64 device_count); + + // Executes the given computation as above Execute(), but launches the + // computation asynchronously and returns before the execution is complete. + // Returns an ExecutionHandle that represents the launched execution, which is + // used to call WaitForExecution() to wait for the execution's completion. + StatusOr ExecuteAsync( + const Computation& computation, + tensorflow::gtl::ArraySlice arguments, + const Shape* shape_with_output_layout = nullptr, uint64 seed = 0); + + // Waits until the given asynchronously launched execution of the computation + // is complete and returns the execution result. Once this is called, the + // given execution handle is no longer valid. If execution_profile is not + // nullptr then the pointed-to ExecutionProfile will be filled with profile + // data from the execution. + StatusOr> WaitForExecution( + const Computation& computation, const ExecutionHandle& execution, + ExecutionProfile* execution_profile = nullptr); + + // Transfer the global data provided to this client process, which is + // returned in the provided literal. Use sparingly to avoid transfer + // overheads. + // + // If shape_with_layout is not nullptr, it points to a shape whose layout will + // be the layout of the returned literal. + StatusOr> Transfer( + const GlobalData& data, const Shape* shape_with_layout = nullptr); + + // Transfer the given literal to the server. This allocates memory on the + // device and copies the literal's contents over. Returns a global data handle + // that can be used to refer to this value from the client. + // + // If device_handle is not nullptr, data is transferred to the associated + // device (and its replicas if replication is enabled). Otherwise, data is + // transferred to the default device (and its replicas). + StatusOr> TransferToServer( + const Literal& literal, const DeviceHandle* device_handle = nullptr); + + // Transfer the given literal to the Infeed interface of the device. + // + // device_handle and replica_id together specify a particular device; a device + // assigned for the given replica_id among the replicas that the given device + // handle belongs to. + Status TransferToInfeed(const Literal& literal, int64 replica_id = 0, + const DeviceHandle* device_handle = nullptr); + + // Resets the device, clearing all existing state on the device. + Status ResetDevice(); + + // Executes the computation with the given arguments and transfers the result + // to the client as a literal. Parameters are defined the same as for + // Execute() and Transfer(). + StatusOr> ExecuteAndTransfer( + const Computation& computation, + tensorflow::gtl::ArraySlice arguments, + const Shape* shape_with_output_layout = nullptr, + ExecutionProfile* execution_profile = nullptr, uint64 seed = 0); + + // Unregister the memory for the given GlobalData on the device. + Status Unregister(const GlobalData& data); + + // Returns a vector of global data handles that point to the tuple elements. + StatusOr>> DeconstructTuple( + const GlobalData& computation); + + // Retrieves the statistics of the given computation. + StatusOr GetComputationStats( + const Computation& computation) const; + + // Returns the Shape of the given array specified by 'data'. The shape + // includes the Layout of the array as it is stored on the service. The layout + // information is useful for calling TransferInProcess. 
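+ //
+ // A sketch of the in-process read path (sizing the destination buffer
+ // with ShapeUtil::ByteSizeOf is illustrative; any equivalent size
+ // computation works):
+ //
+ //   TF_ASSIGN_OR_RETURN(Shape shape, client->GetShape(data));
+ //   std::vector<char> buffer(ShapeUtil::ByteSizeOf(shape));
+ //   TF_RETURN_IF_ERROR(client->TransferInProcess(data, buffer.data()));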
+ StatusOr GetShape(const GlobalData& data); + + // As above, but returns the shape of the provided computation (parameter + // types/names and return type). + StatusOr> GetComputationShape( + const Computation& computation); + + // Creates a channel handle that can be used to transfer data between + // two computations via a pair of Send and Recv instructions. + StatusOr CreateChannelHandle(); + + // If the service is running in the same process as the client then the + // following "InProcess" transfer methods may be used. These methods enable + // more efficient transfer of arrays to and from the service. + + // Transfer array from the service into the given buffer. The buffer must be + // large enough to hold the array. The array is copied verbatim (memcpy) from + // the service. The method GetShape should be called ahead of time + // to get the shape and layout of the array as it is stored in the + // service. The shape and layout can be used to determine how large the buffer + // needs to be. + Status TransferInProcess(const GlobalData& data, void* destination); + + // Transfer array to the service from the given buffer with the given shape + // and layout. The service creates an internal copy of the data so the client + // can free the buffer when this method returns. + StatusOr> TransferToServerInProcess( + const Shape& shape, const void* buffer); + + StatusOr LoadSnapshot(const SessionModule& module); + + ServiceInterface* stub() { return stub_; } + + private: + // Returns the execution statistics (e.g., gflop/s) as a string from the + // ExecutionProfile returned from an execution of the computation. + StatusOr ExecutionStatsAsString(const Computation& computation, + const ExecutionProfile& profile); + + ServiceInterface* stub_; // Stub that this client is connected on. + + TF_DISALLOW_COPY_AND_ASSIGN(Client); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_CLIENT_H_ diff --git a/tensorflow/compiler/xla/client/client_library.cc b/tensorflow/compiler/xla/client/client_library.cc new file mode 100644 index 0000000000..93437023bc --- /dev/null +++ b/tensorflow/compiler/xla/client/client_library.cc @@ -0,0 +1,107 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/client/client_library.h" + +#include "tensorflow/compiler/xla/service/backend.h" +#include "tensorflow/compiler/xla/service/platform_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +LocalClientOptions& LocalClientOptions::set_platform( + perftools::gputools::Platform* platform) { + platform_ = platform; + return *this; +} + +perftools::gputools::Platform* LocalClientOptions::platform() const { + return platform_; +} + +LocalClientOptions& LocalClientOptions::set_number_of_replicas( + int number_of_replicas) { + number_of_replicas_ = number_of_replicas; + return *this; +} + +int LocalClientOptions::number_of_replicas() const { + return number_of_replicas_; +} + +/* static */ ClientLibrary& ClientLibrary::Singleton() { + static ClientLibrary* c = new ClientLibrary; + return *c; +} + +ClientLibrary::ClientLibrary() = default; +ClientLibrary::~ClientLibrary() = default; + +/* static */ StatusOr ClientLibrary::GetOrCreateLocalClient( + perftools::gputools::Platform* platform) { + LocalClientOptions default_options; + default_options.set_platform(platform); + return GetOrCreateLocalClient(default_options); +} + +/* static */ StatusOr ClientLibrary::GetOrCreateLocalClient( + const LocalClientOptions& options) { + perftools::gputools::Platform* platform = options.platform(); + int replica_count = options.number_of_replicas(); + ClientLibrary& client_library = Singleton(); + tensorflow::mutex_lock lock(client_library.service_mutex_); + + if (platform == nullptr) { + TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform()); + } + + auto it = client_library.instances_.find(platform->id()); + if (it != client_library.instances_.end()) { + return it->second->client.get(); + } + + ServiceOptions service_options; + service_options.set_platform(platform); + service_options.set_number_of_replicas(replica_count); + + std::unique_ptr instance = MakeUnique(); + TF_ASSIGN_OR_RETURN(instance->service, + LocalService::NewService(service_options)); + instance->client = MakeUnique(instance->service.get()); + LocalClient* cl = instance->client.get(); + + client_library.instances_.insert( + std::make_pair(platform->id(), std::move(instance))); + return cl; +} + +/* static */ LocalClient* ClientLibrary::LocalClientOrDie() { + auto client_status = GetOrCreateLocalClient(); + TF_CHECK_OK(client_status.status()); + return client_status.ValueOrDie(); +} + +/* static */ LocalService* ClientLibrary::GetXlaService( + perftools::gputools::Platform* platform) { + ClientLibrary& client_library = Singleton(); + tensorflow::mutex_lock lock(client_library.service_mutex_); + auto it = client_library.instances_.find(platform->id()); + CHECK(it != client_library.instances_.end()); + return it->second->service.get(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/client_library.h b/tensorflow/compiler/xla/client/client_library.h new file mode 100644 index 0000000000..2bc319f933 --- /dev/null +++ b/tensorflow/compiler/xla/client/client_library.h @@ -0,0 +1,103 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// The "client library" instantiates a local (in-process) XLA service for +// use by this process, and connects to it with a singleton XLA local +// client. ClientLibrary::GetOrCreateLocalClient will spawn a local service, +// and return a client that's connected to it and ready to run XLA +// computations. +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_CLIENT_LIBRARY_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_CLIENT_LIBRARY_H_ + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/local_service.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/thread_annotations.h" + +namespace xla { + +// Options to configure the local client when it is created. +class LocalClientOptions { + public: + // Set the platform backing the service, or nullptr for the default platform. + LocalClientOptions& set_platform(perftools::gputools::Platform* platform); + perftools::gputools::Platform* platform() const; + + // Set the number of replicas to use when compiling replicated + // programs. The default is -1 meaning that the value is read from + // the xla_replicas flag. + LocalClientOptions& set_number_of_replicas(int number_of_replicas); + int number_of_replicas() const; + + private: + perftools::gputools::Platform* platform_ = nullptr; + int number_of_replicas_ = -1; +}; + +class ClientLibrary { + public: + // Singleton constructor-or-accessor -- returns a client for the application + // to issue XLA commands on. Arguments: + // + // platform : The platform the underlying XLA service should target. If + // null then default platform is used. + static StatusOr GetOrCreateLocalClient( + perftools::gputools::Platform* platform = nullptr); + static StatusOr GetOrCreateLocalClient( + const LocalClientOptions& options); + + // Convenience "or-die" wrapper around the above which returns the existing + // client library or creates one with default platform and allocator. + static LocalClient* LocalClientOrDie(); + + // Returns the service from the service thread. Only used in unit tests to + // access user computations from client. + static LocalService* GetXlaService(perftools::gputools::Platform* platform); + + private: + // Returns the singleton instance of ClientLibrary. + static ClientLibrary& Singleton(); + + ClientLibrary(); + ~ClientLibrary(); + + struct LocalInstance { + // Service that is wrapped by the singleton client object. + std::unique_ptr service; + // Singleton client object. + std::unique_ptr client; + }; + + tensorflow::mutex service_mutex_; // Guards the singleton creation state. 
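+ // In-process service/client pairs, keyed by the id of the platform they
+ // were created for (see GetOrCreateLocalClient).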
+ std::unordered_map> + instances_ GUARDED_BY(service_mutex_); + + TF_DISALLOW_COPY_AND_ASSIGN(ClientLibrary); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_CLIENT_LIBRARY_H_ diff --git a/tensorflow/compiler/xla/client/computation.cc b/tensorflow/compiler/xla/client/computation.cc new file mode 100644 index 0000000000..cd7d8df58b --- /dev/null +++ b/tensorflow/compiler/xla/client/computation.cc @@ -0,0 +1,67 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/computation.h" + +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace xla { + +Computation::Computation() : parent_(nullptr) {} + +Computation::Computation(ServiceInterface* parent, + const ComputationHandle& handle) + : handle_(handle), parent_(parent) {} + +Computation::Computation(Computation&& computation) + : handle_(computation.handle_), parent_(computation.parent_) { + computation.ResetWithoutFreeing(); +} + +void Computation::Reset() { + // TODO(leary) deallocate any owned computation. + ResetWithoutFreeing(); +} + +StatusOr> Computation::Snapshot() const { + SnapshotComputationRequest request; + *request.mutable_computation() = handle_; + SnapshotComputationResponse response; + + TF_RETURN_IF_ERROR(parent_->SnapshotComputation(&request, &response)); + + return WrapUnique(response.release_module()); +} + +Computation::~Computation() { Reset(); } + +Computation& Computation::operator=(Computation&& computation) { + if (&computation != this) { + Reset(); + handle_ = computation.handle_; + parent_ = computation.parent_; + computation.ResetWithoutFreeing(); + } + return *this; +} + +void Computation::ResetWithoutFreeing() { + handle_.Clear(); + parent_ = nullptr; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/computation.h b/tensorflow/compiler/xla/client/computation.h new file mode 100644 index 0000000000..b595172486 --- /dev/null +++ b/tensorflow/compiler/xla/client/computation.h @@ -0,0 +1,76 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_H_ + +#include + +#include "tensorflow/compiler/xla/service/session.pb.h" +#include "tensorflow/compiler/xla/service_interface.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/xla.pb.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/macros.h" + +namespace xla { + +// Wraps a ComputationHandle protobuf with a lifetime. Computation is +// movable and not copyable to capture the same kind of unique +// ownership that std::unique_ptr represents. +class Computation { + public: + // Creates a null Computation. + Computation(); + + // parent: stub for the service on which we will deallocate the computation + // when it is no longer needed. + // handle: the computation handle protobuf from the service. + Computation(ServiceInterface* parent, const ComputationHandle& handle); + + Computation(Computation&& computation); + + // Deallocates the computation. + ~Computation(); + + Computation& operator=(Computation&& computation); + + // Returns the underlying handle. + const ComputationHandle& handle() const { return handle_; } + + // Sets handle to a null state and clears any owned computation. + void Reset(); + + // Requests that we snapshot the computation into a serializable protocol + // buffer form. + StatusOr> Snapshot() const; + + // Returns true if this object is a null Computation. + bool IsNull() const { return parent_ == nullptr; } + + private: + void ResetWithoutFreeing(); + + ComputationHandle handle_; // Handle that is wrapped by this class. + + // Stub that the handle is deallocated on when this object's lifetime ends. + ServiceInterface* parent_; + + TF_DISALLOW_COPY_AND_ASSIGN(Computation); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_H_ diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc new file mode 100644 index 0000000000..2b8b0b6ae5 --- /dev/null +++ b/tensorflow/compiler/xla/client/computation_builder.cc @@ -0,0 +1,1539 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/client/computation_builder.h" + +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace xla { + +ComputationDataHandle ComputationBuilder::ParseOpResponse( + const Status& status, OpResponse* response) { + VLOG(2) << "done with op request"; + + if (!status.ok()) { + NoteError(status); + return ComputationDataHandle(); + } + + if (response->output().handle() == 0) { + NoteError(InternalError("No output handle")); + return ComputationDataHandle(); + } + return response->output(); +} + +ComputationBuilder::ComputationBuilder(Client* client, + const string& computation_name) + : name_(computation_name), first_error_(Status::OK()), client_(client) {} + +ComputationBuilder::~ComputationBuilder() {} + +void ComputationBuilder::NoteError(const Status& error) { + if (die_immediately_on_error_) { + LOG(FATAL) << "error building computation: " << error; + } + + if (first_error_.ok()) { + first_error_ = error; + first_error_backtrace_.CreateCurrent(/*skip_count=*/1); + } +} + +std::unique_ptr ComputationBuilder::CreateSubBuilder( + const string& computation_name) { + auto sub_builder = MakeUnique(client_, computation_name); + sub_builder->parent_builder_ = this; + sub_builder->die_immediately_on_error_ = die_immediately_on_error_; + return sub_builder; +} + +Status ComputationBuilder::PrepareComputation() { + if (!first_error_.ok()) { + return first_error_; + } + if (!computation_.IsNull()) { + return Status::OK(); + } + + ComputationRequest request; + request.set_name(name_); + ComputationResponse response; + + VLOG(2) << "making computation request"; + Status s = client_->stub()->Computation(&request, &response); + VLOG(2) << "done with computation request"; + + if (!s.ok()) { + NoteError(s); + return first_error_; + } + + computation_ = Computation(client_->stub(), response.computation()); + return Status::OK(); +} + +bool ComputationBuilder::MakeWindow( + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + tensorflow::gtl::ArraySlice lhs_dilation, + tensorflow::gtl::ArraySlice rhs_dilation, Window* window) { + const auto verify_size = [&](const int64 x, const char* x_name) { + if (x == 0 || x == window_dimensions.size()) { + return true; + } else { + NoteError(InvalidArgument( + "%s", + tensorflow::strings::StrCat( + "Window has different number of window dimensions than of ", + x_name, "\nNumber of window dimensions: ", + window_dimensions.size(), "\nNumber of ", x_name, ": ", x, + "\n") + .c_str())); // + return false; + } + }; + if (!verify_size(window_strides.size(), "window strides") || + !verify_size(padding.size(), "padding entries") || + !verify_size(lhs_dilation.size(), "lhs dilation factors") || + !verify_size(rhs_dilation.size(), "rhs dilation factors")) { + return false; + } + + window->Clear(); + for (size_t i = 0; i < window_dimensions.size(); i++) { + auto dim = window->add_dimensions(); + 
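+ // Populate each window dimension entry, substituting the conventional
+ // defaults (stride 1, no padding, no dilation) for any specification the
+ // caller left empty.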
dim->set_size(window_dimensions[i]); + if (!window_strides.empty()) { + dim->set_stride(window_strides[i]); + } else { + dim->set_stride(1); + } + if (!padding.empty()) { + dim->set_padding_low(padding[i].first); + dim->set_padding_high(padding[i].second); + } else { + dim->set_padding_low(0); + dim->set_padding_high(0); + } + if (!lhs_dilation.empty()) { + dim->set_base_dilation(lhs_dilation[i]); + } else { + dim->set_base_dilation(1); + } + if (!rhs_dilation.empty()) { + dim->set_window_dilation(rhs_dilation[i]); + } else { + dim->set_window_dilation(1); + } + } + return true; +} + +ComputationDataHandle ComputationBuilder::ConstantOp( + const PopulateLiteral& populate) { + if (!first_error_.ok() || !PrepareComputation().ok()) { + return ComputationDataHandle(); + } + + ConstantRequest request; + Literal* literal = request.mutable_literal(); + populate(literal); + VLOG(3) << "created constant: " << literal->ShortDebugString(); + OpRequest op_request; + *op_request.mutable_constant_request() = request; + *op_request.mutable_computation() = computation_.handle(); + OpResponse response; + + VLOG(2) << "making constant request"; + Status s = client_->stub()->Op(&op_request, &response); + return ParseOpResponse(s, &response); +} + +ComputationDataHandle ComputationBuilder::ConstantLiteral( + const Literal& literal) { + return ConstantOp( + [literal](Literal* mutable_literal) { *mutable_literal = literal; }); +} + +ComputationDataHandle ComputationBuilder::Parameter(int64 parameter_number, + const Shape& shape, + const string& name) { + if (!first_error_.ok() || !PrepareComputation().ok()) { + return ComputationDataHandle(); + } + + ParameterRequest request; + *request.mutable_shape() = shape; + request.set_parameter(parameter_number); + request.set_name(name); + OpRequest op_request; + *op_request.mutable_computation() = computation_.handle(); + *op_request.mutable_parameter_request() = request; + OpResponse response; + + VLOG(2) << "making parameter request"; + Status s = client_->stub()->Op(&op_request, &response); + return ParseOpResponse(s, &response); +} + +StatusOr> ComputationBuilder::GetShape( + const ComputationDataHandle& operand) { + if (!first_error_.ok()) { + return first_error_; + } + + GetLocalShapeRequest request; + *request.mutable_computation() = computation_.handle(); + *request.mutable_operand() = operand; + GetLocalShapeResponse response; + + VLOG(2) << "making get-shape request"; + Status s = client_->stub()->GetLocalShape(&request, &response); + VLOG(2) << "done with request"; + + if (!s.ok()) { + NoteError(s); + return first_error_; + } + TF_RET_CHECK(response.has_shape()); + std::unique_ptr shape = WrapUnique(response.release_shape()); + TF_RET_CHECK(shape != nullptr); + return std::move(shape); +} + +ComputationDataHandle ComputationBuilder::CheckShape( + const ComputationDataHandle& operand, const Shape& expected_shape) { + std::unique_ptr actual_shape = GetShape(operand).ConsumeValueOrDie(); + CHECK(ShapeUtil::Equal(expected_shape, *actual_shape)) + << "want " << ShapeUtil::HumanString(expected_shape) << " got " + << ShapeUtil::HumanString(*actual_shape); + return operand; +} + +void ComputationBuilder::CheckSameShape(const ComputationDataHandle& lhs, + const ComputationDataHandle& rhs) { + std::unique_ptr lhs_shape = GetShape(lhs).ConsumeValueOrDie(); + std::unique_ptr rhs_shape = GetShape(rhs).ConsumeValueOrDie(); + VLOG(2) << "checking " << ShapeUtil::HumanString(*lhs_shape) << " equals " + << ShapeUtil::HumanString(*rhs_shape); + 
CHECK(ShapeUtil::Equal(*lhs_shape, *rhs_shape)) + << "lhs " << ShapeUtil::HumanString(*lhs_shape) << " rhs " + << ShapeUtil::HumanString(*rhs_shape); +} + +ComputationDataHandle ComputationBuilder::Slice( + const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices) { + if (!first_error_.ok() || !PrepareComputation().ok()) { + return ComputationDataHandle(); + } + + SliceRequest request; + *request.mutable_operand() = operand; + for (int64 index : start_indices) { + request.add_start_indices(index); + } + for (int64 index : limit_indices) { + request.add_limit_indices(index); + } + OpRequest op_request; + *op_request.mutable_computation() = computation_.handle(); + *op_request.mutable_slice_request() = request; + OpResponse response; + + VLOG(2) << "making slice request"; + Status s = client_->stub()->Op(&op_request, &response); + return ParseOpResponse(s, &response); +} + +ComputationDataHandle ComputationBuilder::DynamicSlice( + const ComputationDataHandle& operand, + const ComputationDataHandle& start_indices, + tensorflow::gtl::ArraySlice slice_sizes) { + if (!first_error_.ok() || !PrepareComputation().ok()) { + return ComputationDataHandle(); + } + + DynamicSliceRequest request; + *request.mutable_operand() = operand; + *request.mutable_start_indices() = start_indices; + for (int64 index : slice_sizes) { + request.add_slice_sizes(index); + } + OpRequest op_request; + *op_request.mutable_computation() = computation_.handle(); + *op_request.mutable_dynamic_slice_request() = request; + OpResponse response; + + VLOG(2) << "making dynamic slice request"; + Status s = client_->stub()->Op(&op_request, &response); + return ParseOpResponse(s, &response); +} + +ComputationDataHandle ComputationBuilder::DynamicUpdateSlice( + const ComputationDataHandle& operand, const ComputationDataHandle& update, + const ComputationDataHandle& start_indices) { + if (!first_error_.ok() || !PrepareComputation().ok()) { + return ComputationDataHandle(); + } + + DynamicUpdateSliceRequest request; + *request.mutable_operand() = operand; + *request.mutable_update() = update; + *request.mutable_start_indices() = start_indices; + OpRequest op_request; + *op_request.mutable_computation() = computation_.handle(); + *op_request.mutable_dynamic_update_slice_request() = request; + OpResponse response; + + VLOG(2) << "making dynamic update slice request"; + Status s = client_->stub()->Op(&op_request, &response); + return ParseOpResponse(s, &response); +} + +ComputationDataHandle ComputationBuilder::ConcatInDim( + tensorflow::gtl::ArraySlice operands, + int64 dimension) { + if (!first_error_.ok() || !PrepareComputation().ok()) { + return ComputationDataHandle(); + } + + ConcatenateRequest request; + for (const ComputationDataHandle& operand : operands) { + *request.add_operands() = operand; + } + request.set_dimension(dimension); + OpRequest op_request; + *op_request.mutable_computation() = computation_.handle(); + *op_request.mutable_concatenate_request() = request; + OpResponse response; + + VLOG(2) << "making concatenate request"; + Status s = client_->stub()->Op(&op_request, &response); + return ParseOpResponse(s, &response); +} + +ComputationDataHandle ComputationBuilder::Broadcast( + const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice broadcast_sizes) { + if (!first_error_.ok() || !PrepareComputation().ok()) { + return ComputationDataHandle(); + } + + BroadcastRequest request; + *request.mutable_operand() = operand; + for (int64 size : 
broadcast_sizes) { + request.add_broadcast_sizes(size); + } + OpRequest op_request; + *op_request.mutable_computation() = computation_.handle(); + *op_request.mutable_broadcast_request() = request; + OpResponse response; + + VLOG(2) << "making broadcast request"; + Status s = client_->stub()->Op(&op_request, &response); + return ParseOpResponse(s, &response); +} + +ComputationDataHandle ComputationBuilder::Pad( + const ComputationDataHandle& operand, + const ComputationDataHandle& padding_value, + const PaddingConfig& padding_config) { + if (!first_error_.ok() || !PrepareComputation().ok()) { + return ComputationDataHandle(); + } + + PadRequest request; + *request.mutable_operand() = operand; + *request.mutable_padding_value() = padding_value; + *request.mutable_padding_config() = padding_config; + OpRequest op_request; + *op_request.mutable_computation() = computation_.handle(); + *op_request.mutable_pad_request() = request; + OpResponse response; + + VLOG(2) << "making pad request"; + Status s = client_->stub()->Op(&op_request, &response); + return ParseOpResponse(s, &response); +} + +ComputationDataHandle ComputationBuilder::Reshape( + const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice new_sizes) { + if (!first_error_.ok() || !PrepareComputation().ok()) { + return ComputationDataHandle(); + } + + ReshapeRequest request; + *request.mutable_operand() = operand; + for (int64 dimension : dimensions) { + request.add_dimensions(dimension); + } + for (int64 new_size : new_sizes) { + request.add_new_sizes(new_size); + } + OpRequest op_request; + *op_request.mutable_computation() = computation_.handle(); + *op_request.mutable_reshape_request() = request; + OpResponse response; + + VLOG(2) << "making reshape request"; + Status s = client_->stub()->Op(&op_request, &response); + return ParseOpResponse(s, &response); +} + +ComputationDataHandle ComputationBuilder::Reshape( + const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice new_sizes) { + if (!first_error_.ok()) { + return ComputationDataHandle(); + } + + StatusOr> shape = GetShape(operand); + if (!shape.ok()) { + // Just early return with the existing error status. + first_error_ = shape.status(); + return ComputationDataHandle(); + } + std::vector dimensions(shape.ValueOrDie()->dimensions().size()); + std::iota(dimensions.begin(), dimensions.end(), 0); + return Reshape(operand, dimensions, new_sizes); +} + +ComputationDataHandle ComputationBuilder::Collapse( + const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice dims_to_collapse) { + if (!first_error_.ok()) { + return ComputationDataHandle(); + } + + // Don't support out-of-order collapse here. + // Checks that the collapsed dimensions are in order and consecutive. + for (int i = 1; i < dims_to_collapse.size(); ++i) { + if (dims_to_collapse[i] - 1 != dims_to_collapse[i - 1]) { + NoteError(InvalidArgument( + "Collapsed dimensions are not in order and consecutive.")); + return ComputationDataHandle(); + } + } + + // Create a new sizes vector from the old shape, replacing the collapsed + // dimensions by the product of their sizes. + StatusOr> shape_or_status = GetShape(operand); + if (!shape_or_status.ok()) { + // Just early return with the existing error status. 
+ first_error_ = shape_or_status.status(); + return ComputationDataHandle(); + } + std::unique_ptr original_shape = shape_or_status.ConsumeValueOrDie(); + + std::vector new_sizes; + for (int i = 0; i < ShapeUtil::Rank(*original_shape); ++i) { + if (i <= dims_to_collapse.front() || i > dims_to_collapse.back()) { + new_sizes.push_back(original_shape->dimensions(i)); + } else { + new_sizes.back() *= original_shape->dimensions(i); + } + } + + return Reshape(operand, new_sizes); +} + +void ComputationBuilder::Trace(const string& tag, + const ComputationDataHandle& operand) { + if (!first_error_.ok() || !PrepareComputation().ok()) { + return; + } + + TraceRequest request; + request.set_tag(tag); + *request.mutable_operand() = operand; + OpRequest op_request; + *op_request.mutable_computation() = computation_.handle(); + *op_request.mutable_trace_request() = request; + OpResponse response; + + VLOG(2) << "making trace request"; + Status s = client_->stub()->Op(&op_request, &response); + VLOG(2) << "done with request"; + + if (!s.ok()) { + NoteError(s); + } +} + +ComputationDataHandle ComputationBuilder::Select( + const ComputationDataHandle& pred, const ComputationDataHandle& on_true, + const ComputationDataHandle& on_false) { + return TernaryOp(TRIOP_SELECT, pred, on_true, on_false); +} + +ComputationDataHandle ComputationBuilder::Tuple( + tensorflow::gtl::ArraySlice elements) { + if (!first_error_.ok() || !PrepareComputation().ok()) { + return ComputationDataHandle(); + } + + VariadicOpRequest request; + request.set_varop(VAROP_TUPLE); + for (const ComputationDataHandle& operand : elements) { + *request.add_operands() = operand; + } + OpRequest op_request; + *op_request.mutable_computation() = computation_.handle(); + *op_request.mutable_variadic_op_request() = request; + OpResponse response; + + VLOG(2) << "making variadic op request"; + Status s = client_->stub()->Op(&op_request, &response); + return ParseOpResponse(s, &response); +} + +ComputationDataHandle ComputationBuilder::GetTupleElement( + const ComputationDataHandle& tuple_data, int64 index) { + if (!first_error_.ok() || !PrepareComputation().ok()) { + return ComputationDataHandle(); + } + + GetTupleElementRequest request; + *request.mutable_operand() = tuple_data; + request.set_index(index); + OpRequest op_request; + *op_request.mutable_computation() = computation_.handle(); + *op_request.mutable_get_tuple_element_request() = request; + OpResponse response; + + VLOG(2) << "making get tuple element op request"; + Status s = client_->stub()->Op(&op_request, &response); + return ParseOpResponse(s, &response); +} + +ComputationDataHandle ComputationBuilder::Eq( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(BINOP_EQ, lhs, rhs, broadcast_dimensions); +} + +ComputationDataHandle ComputationBuilder::Ne( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(BINOP_NE, lhs, rhs, broadcast_dimensions); +} + +ComputationDataHandle ComputationBuilder::Ge( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(BINOP_GE, lhs, rhs, broadcast_dimensions); +} + +ComputationDataHandle ComputationBuilder::Gt( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(BINOP_GT, lhs, rhs, 
broadcast_dimensions); +} + +ComputationDataHandle ComputationBuilder::Le( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(BINOP_LE, lhs, rhs, broadcast_dimensions); +} + +ComputationDataHandle ComputationBuilder::Lt( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(BINOP_LT, lhs, rhs, broadcast_dimensions); +} + +ComputationDataHandle ComputationBuilder::Dot( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs) { + return BinaryOp(BINOP_DOT, lhs, rhs, /*broadcast_dimensions=*/{}); +} + +ComputationDataHandle ComputationBuilder::Conv( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice window_strides, Padding padding) { + if (!first_error_.ok() || !PrepareComputation().ok()) { + return ComputationDataHandle(); + } + + return ConvWithGeneralDimensions( + lhs, rhs, window_strides, padding, + CreateDefaultConvDimensionNumbers(window_strides.size())); +} + +ComputationDataHandle ComputationBuilder::ConvWithGeneralPadding( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding) { + if (!first_error_.ok() || !PrepareComputation().ok()) { + return ComputationDataHandle(); + } + + return ConvGeneral(lhs, rhs, window_strides, padding, + CreateDefaultConvDimensionNumbers(window_strides.size())); +} + +bool ComputationBuilder::VerifyConvolution( + const Shape& lhs_shape, const Shape& rhs_shape, + const ConvolutionDimensionNumbers& dimension_numbers) { + if (ShapeUtil::Rank(lhs_shape) != ShapeUtil::Rank(rhs_shape)) { + NoteError( + InvalidArgument("Convolution arguments must have same number of " + "dimensions. Got: %s and %s", + ShapeUtil::HumanString(lhs_shape).c_str(), + ShapeUtil::HumanString(rhs_shape).c_str())); + return false; + } + int num_dims = ShapeUtil::Rank(lhs_shape); + if (num_dims < 3) { + NoteError(InvalidArgument( + "Convolution expects argument arrays with >= 3 dimensions. 
" + "Got: %s and %s", + ShapeUtil::HumanString(lhs_shape).c_str(), + ShapeUtil::HumanString(rhs_shape).c_str())); + return false; + } + int num_spatial_dims = num_dims - 2; + + const auto check_spatial_dimensions = [&]( + const char* const field_name, + const tensorflow::protobuf::RepeatedField& + numbers) { + if (numbers.size() != num_spatial_dims) { + NoteError(InvalidArgument("Expected %d elements for %s, but got %d.", + num_spatial_dims, field_name, numbers.size())); + return false; + } + for (int i = 0; i < numbers.size(); ++i) { + if (numbers.Get(i) < 0 || numbers.Get(i) >= num_dims) { + NoteError(InvalidArgument("Convolution %s[%d] is out of bounds: %lld", + field_name, i, numbers.Get(i))); + return false; + } + } + return true; + }; + return check_spatial_dimensions("spatial_dimensions", + dimension_numbers.spatial_dimensions()) && + check_spatial_dimensions( + "kernel_spatial_dimensions", + dimension_numbers.kernel_spatial_dimensions()); +} + +ComputationDataHandle ComputationBuilder::ConvWithGeneralDimensions( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice window_strides, Padding padding, + const ConvolutionDimensionNumbers& dimension_numbers) { + if (!first_error_.ok() || !PrepareComputation().ok()) { + return ComputationDataHandle(); + } + + StatusOr> lhs_shape_or_status = GetShape(lhs); + if (!lhs_shape_or_status.ok()) { + first_error_ = lhs_shape_or_status.status(); + return ComputationDataHandle(); + } + + StatusOr> rhs_shape_or_status = GetShape(rhs); + if (!rhs_shape_or_status.ok()) { + first_error_ = rhs_shape_or_status.status(); + return ComputationDataHandle(); + } + + std::unique_ptr lhs_shape = lhs_shape_or_status.ConsumeValueOrDie(); + std::unique_ptr rhs_shape = rhs_shape_or_status.ConsumeValueOrDie(); + + if (!VerifyConvolution(*lhs_shape, *rhs_shape, dimension_numbers)) { + NoteError(InternalError("failed to verify convolution")); + return ComputationDataHandle(); + } + + std::vector base_area_dimensions( + dimension_numbers.spatial_dimensions_size()); + for (int i = 0; i < base_area_dimensions.size(); ++i) { + base_area_dimensions[i] = + lhs_shape->dimensions(dimension_numbers.spatial_dimensions(i)); + } + + std::vector window_dimensions( + dimension_numbers.kernel_spatial_dimensions_size()); + for (int i = 0; i < window_dimensions.size(); ++i) { + window_dimensions[i] = + rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i)); + } + + return ConvGeneral(lhs, rhs, window_strides, + MakePadding(base_area_dimensions, window_dimensions, + window_strides, padding), + dimension_numbers); +} + +ComputationDataHandle ComputationBuilder::ConvGeneral( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + const ConvolutionDimensionNumbers& dimension_numbers) { + return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {}, + dimension_numbers); +} + +ComputationDataHandle ComputationBuilder::ConvGeneralDilated( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + tensorflow::gtl::ArraySlice lhs_dilation, + tensorflow::gtl::ArraySlice rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers) { + if (!first_error_.ok() || !PrepareComputation().ok()) { + return ComputationDataHandle(); + } + + StatusOr> lhs_shape_or_status = GetShape(lhs); + if (!lhs_shape_or_status.ok()) { + 
first_error_ = lhs_shape_or_status.status(); + return ComputationDataHandle(); + } + + StatusOr> rhs_shape_or_status = GetShape(rhs); + if (!rhs_shape_or_status.ok()) { + first_error_ = rhs_shape_or_status.status(); + return ComputationDataHandle(); + } + + std::unique_ptr lhs_shape = lhs_shape_or_status.ConsumeValueOrDie(); + std::unique_ptr rhs_shape = rhs_shape_or_status.ConsumeValueOrDie(); + if (!VerifyConvolution(*lhs_shape, *rhs_shape, dimension_numbers)) { + // Error is recorded in VerifyConvolution. + return ComputationDataHandle(); + } + + std::vector window_dimensions( + dimension_numbers.kernel_spatial_dimensions_size()); + for (int i = 0; i < window_dimensions.size(); ++i) { + window_dimensions[i] = + rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i)); + } + + ConvolveRequest request; + *request.mutable_lhs() = lhs; + *request.mutable_rhs() = rhs; + *request.mutable_dimension_numbers() = dimension_numbers; + + if (!MakeWindow(window_dimensions, window_strides, padding, lhs_dilation, + rhs_dilation, request.mutable_window())) { + // Error is recorded in MakeWindow. + return ComputationDataHandle(); + } + OpRequest op_request; + *op_request.mutable_computation() = computation_.handle(); + *op_request.mutable_convolve_request() = request; + OpResponse response; + + VLOG(2) << "making convolve request"; + Status s = client_->stub()->Op(&op_request, &response); + return ParseOpResponse(s, &response); +} + +ComputationDataHandle ComputationBuilder::Infeed(const Shape& shape) { + if (!first_error_.ok() || !PrepareComputation().ok()) { + return ComputationDataHandle(); + } + + InfeedRequest request; + *request.mutable_shape() = shape; + OpRequest op_request; + *op_request.mutable_computation() = computation_.handle(); + *op_request.mutable_infeed_request() = request; + OpResponse response; + + VLOG(2) << "making infeed op request"; + Status s = client_->stub()->Op(&op_request, &response); + + return ParseOpResponse(s, &response); +} + +ComputationDataHandle ComputationBuilder::Call( + const Computation& computation, + tensorflow::gtl::ArraySlice operands) { + if (!first_error_.ok() || !PrepareComputation().ok()) { + return ComputationDataHandle(); + } + + CallRequest request; + *request.mutable_to_apply() = computation.handle(); + for (const ComputationDataHandle& operand : operands) { + *request.add_operands() = operand; + } + OpRequest op_request; + *op_request.mutable_computation() = computation_.handle(); + *op_request.mutable_call_request() = request; + OpResponse response; + + VLOG(2) << "making call op request"; + Status s = client_->stub()->Op(&op_request, &response); + + return ParseOpResponse(s, &response); +} + +ComputationDataHandle ComputationBuilder::CustomCall( + const string& call_target_name, + tensorflow::gtl::ArraySlice operands, + const Shape& shape) { + if (!first_error_.ok() || !PrepareComputation().ok()) { + return ComputationDataHandle(); + } + + CustomCallRequest request; + request.set_call_target_name(call_target_name); + for (const ComputationDataHandle& operand : operands) { + *request.add_operands() = operand; + } + *request.mutable_shape() = shape; + OpRequest op_request; + *op_request.mutable_computation() = computation_.handle(); + *op_request.mutable_custom_call_request() = request; + OpResponse response; + + VLOG(2) << "making custom call op request"; + Status s = client_->stub()->Op(&op_request, &response); + + return ParseOpResponse(s, &response); +} + +ComputationDataHandle ComputationBuilder::Add( + const ComputationDataHandle& 
lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { + return BinaryOp(BINOP_ADD, lhs, rhs, broadcast_dimensions); +} + +ComputationDataHandle ComputationBuilder::Sub( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { + return BinaryOp(BINOP_SUB, lhs, rhs, broadcast_dimensions); +} + +ComputationDataHandle ComputationBuilder::Mul( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { + return BinaryOp(BINOP_MUL, lhs, rhs, broadcast_dimensions); +} + +ComputationDataHandle ComputationBuilder::Div( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { + return BinaryOp(BINOP_DIV, lhs, rhs, broadcast_dimensions); +} + +ComputationDataHandle ComputationBuilder::Rem( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { + return BinaryOp(BINOP_REM, lhs, rhs, broadcast_dimensions); +} + +ComputationDataHandle ComputationBuilder::Max( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { + return BinaryOp(BINOP_MAX, lhs, rhs, broadcast_dimensions); +} + +ComputationDataHandle ComputationBuilder::Min( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { + return BinaryOp(BINOP_MIN, lhs, rhs, broadcast_dimensions); +} + +ComputationDataHandle ComputationBuilder::LogicalAnd( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { + return BinaryOp(BINOP_LOGICAL_AND, lhs, rhs, broadcast_dimensions); +} + +ComputationDataHandle ComputationBuilder::LogicalOr( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { + return BinaryOp(BINOP_LOGICAL_OR, lhs, rhs, broadcast_dimensions); +} + +ComputationDataHandle ComputationBuilder::LogicalNot( + const ComputationDataHandle& operand) { + return UnaryOp(UNOP_LOGICAL_NOT, operand); +} + +ComputationDataHandle ComputationBuilder::Abs( + const ComputationDataHandle& operand) { + return UnaryOp(UNOP_ABS, operand); +} + +ComputationDataHandle ComputationBuilder::Exp( + const ComputationDataHandle& operand) { + return UnaryOp(UNOP_EXP, operand); +} + +ComputationDataHandle ComputationBuilder::Floor( + const ComputationDataHandle& operand) { + return UnaryOp(UNOP_FLOOR, operand); +} + +ComputationDataHandle ComputationBuilder::Ceil( + const ComputationDataHandle& operand) { + return UnaryOp(UNOP_CEIL, operand); +} + +ComputationDataHandle ComputationBuilder::Log( + const ComputationDataHandle& operand) { + return UnaryOp(UNOP_LOG, operand); +} + +ComputationDataHandle ComputationBuilder::Sign( + const ComputationDataHandle& operand) { + return UnaryOp(UNOP_SIGN, operand); +} + +ComputationDataHandle ComputationBuilder::Tanh( + const ComputationDataHandle& operand) { + return UnaryOp(UNOP_TANH, operand); +} + +ComputationDataHandle ComputationBuilder::Transpose( + const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice<int64> permutation) { + if (!first_error_.ok()) { + return ComputationDataHandle(); + } + + StatusOr<std::unique_ptr<Shape>> shape = GetShape(operand); + if (!shape.ok()) { + // Just early return with the existing error status.
+ first_error_ = shape.status(); + return ComputationDataHandle(); + } + return Reshape(operand, permutation, + Permute(InversePermutation(permutation), + AsInt64Slice(shape.ValueOrDie()->dimensions()))); +} + +ComputationDataHandle ComputationBuilder::Rev( + const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice dimensions) { + if (!first_error_.ok() || !PrepareComputation().ok()) { + return ComputationDataHandle(); + } + + ReverseRequest request; + *request.mutable_operand() = operand; + for (int64 dimension : dimensions) { + request.add_dimensions(dimension); + } + OpRequest op_request; + *op_request.mutable_computation() = computation_.handle(); + *op_request.mutable_reverse_request() = request; + OpResponse response; + + VLOG(2) << "making reverse op request"; + Status s = client_->stub()->Op(&op_request, &response); + + return ParseOpResponse(s, &response); +} + +ComputationDataHandle ComputationBuilder::Sort( + const ComputationDataHandle& operand) { + return UnaryOp(UNOP_SORT, operand); +} + +ComputationDataHandle ComputationBuilder::SqrtF32( + const ComputationDataHandle& operand) { + return BinaryOp(BINOP_POW, operand, ConstantR0(0.5), + /*broadcast_dimensions=*/{}); +} + +ComputationDataHandle ComputationBuilder::Pow( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs) { + return BinaryOp(BINOP_POW, lhs, rhs, /*broadcast_dimensions=*/{}); +} + +ComputationDataHandle ComputationBuilder::ConvertElementType( + const ComputationDataHandle& operand, PrimitiveType new_element_type) { + if (!first_error_.ok() || !PrepareComputation().ok()) { + return ComputationDataHandle(); + } + + StatusOr> shape_status = GetShape(operand); + if (!shape_status.ok()) { + // Just early return with the existing error status. + first_error_ = shape_status.status(); + return ComputationDataHandle(); + } + std::unique_ptr original = shape_status.ConsumeValueOrDie(); + + ConvertRequest request; + *request.mutable_operand() = operand; + request.set_new_element_type(new_element_type); + OpRequest op_request; + *op_request.mutable_computation() = computation_.handle(); + *op_request.mutable_convert_request() = request; + OpResponse response; + + VLOG(2) << "making convert request"; + Status s = client_->stub()->Op(&op_request, &response); + + return ParseOpResponse(s, &response); +} + +ComputationDataHandle ComputationBuilder::SquareF32( + const ComputationDataHandle& operand) { + return BinaryOp(BINOP_POW, operand, ConstantR0(2.0), + /*broadcast_dimensions=*/{}); +} + +ComputationDataHandle ComputationBuilder::ReciprocalF32( + const ComputationDataHandle& operand) { + return BinaryOp(BINOP_POW, operand, ConstantR0(-1.0), + /*broadcast_dimensions=*/{}); +} + +ComputationDataHandle ComputationBuilder::Neg( + const ComputationDataHandle& operand) { + return UnaryOp(UNOP_NEGATE, operand); +} + +ComputationDataHandle ComputationBuilder::Clamp( + const ComputationDataHandle& min, const ComputationDataHandle& operand, + const ComputationDataHandle& max) { + return TernaryOp(TRIOP_CLAMP, min, operand, max); +} + +ComputationDataHandle ComputationBuilder::UnaryOp( + UnaryOperation unop, const ComputationDataHandle& operand) { + if (!first_error_.ok() || !PrepareComputation().ok()) { + return ComputationDataHandle(); + } + + UnaryOpRequest request; + request.set_unop(unop); + *request.mutable_operand() = operand; + OpRequest op_request; + *op_request.mutable_computation() = computation_.handle(); + *op_request.mutable_unary_op_request() = request; + OpResponse response; + + VLOG(2) << 
"making unop request"; + Status s = client_->stub()->Op(&op_request, &response); + + return ParseOpResponse(s, &response); +} + +ComputationDataHandle ComputationBuilder::BinaryOp( + BinaryOperation binop, const ComputationDataHandle& lhs, + const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + if (!first_error_.ok() || !PrepareComputation().ok()) { + return ComputationDataHandle(); + } + + BinaryOpRequest request; + request.set_binop(binop); + *request.mutable_lhs() = lhs; + *request.mutable_rhs() = rhs; + for (int64 dimension : broadcast_dimensions) { + request.add_broadcast_dimensions(dimension); + } + OpRequest op_request; + *op_request.mutable_computation() = computation_.handle(); + *op_request.mutable_binary_op_request() = request; + OpResponse response; + + VLOG(2) << "making binop request"; + Status s = client_->stub()->Op(&op_request, &response); + + return ParseOpResponse(s, &response); +} + +ComputationDataHandle ComputationBuilder::RngOp( + RandomDistribution distribution, + tensorflow::gtl::ArraySlice parameters, + const Shape& shape) { + if (!first_error_.ok() || !PrepareComputation().ok()) { + return ComputationDataHandle(); + } + + RngRequest request; + request.set_distribution(distribution); + for (const ComputationDataHandle& param : parameters) { + *request.add_parameter() = param; + } + *request.mutable_shape() = shape; + OpRequest op_request; + *op_request.mutable_computation() = computation_.handle(); + *op_request.mutable_rng_request() = request; + OpResponse response; + + VLOG(2) << "making rngop request"; + Status s = client_->stub()->Op(&op_request, &response); + + return ParseOpResponse(s, &response); +} + +ComputationDataHandle ComputationBuilder::TernaryOp( + TernaryOperation triop, const ComputationDataHandle& lhs, + const ComputationDataHandle& rhs, const ComputationDataHandle& ehs) { + if (!first_error_.ok() || !PrepareComputation().ok()) { + return ComputationDataHandle(); + } + + TernaryOpRequest request; + request.set_triop(triop); + *request.mutable_lhs() = lhs; + *request.mutable_rhs() = rhs; + *request.mutable_ehs() = ehs; + OpRequest op_request; + *op_request.mutable_computation() = computation_.handle(); + *op_request.mutable_ternary_op_request() = request; + OpResponse response; + + VLOG(2) << "making triop request"; + Status s = client_->stub()->Op(&op_request, &response); + + return ParseOpResponse(s, &response); +} + +Status ComputationBuilder::SetReturnValue( + const ComputationDataHandle& operand) { + if (!first_error_.ok()) { + return first_error_; + } + + SetReturnValueRequest request; + *request.mutable_computation() = computation_.handle(); + *request.mutable_operand() = operand; + + SetReturnValueResponse response; + + VLOG(2) << "making set-handle-to-execute request"; + Status s = client_->stub()->SetReturnValue(&request, &response); + VLOG(2) << "done with request"; + + if (!s.ok()) { + NoteError(s); + return first_error_; + } + + return Status::OK(); +} + +StatusOr ComputationBuilder::IsConstant( + const ComputationDataHandle& operand) { + if (!first_error_.ok()) { + return first_error_; + } + + IsConstantRequest request; + *request.mutable_computation() = computation_.handle(); + *request.mutable_operand() = operand; + IsConstantResponse response; + + VLOG(2) << "making IsConstant request"; + Status s = client_->stub()->IsConstant(&request, &response); + VLOG(2) << "done with request"; + + if (!s.ok()) { + NoteError(s); + return first_error_; + } + return response.is_constant(); +} + +StatusOr> 
ComputationBuilder::ComputeConstant( + const ComputationDataHandle& operand, const Layout* output_layout) { + if (!first_error_.ok()) { + return first_error_; + } + + ComputeConstantRequest request; + *request.mutable_computation() = computation_.handle(); + *request.mutable_operand() = operand; + if (output_layout != nullptr) { + *request.mutable_output_layout() = *output_layout; + } + + ComputeConstantResponse response; + + VLOG(2) << "making compute-constant request"; + Status s = client_->stub()->ComputeConstant(&request, &response); + VLOG(2) << "done with request"; + + if (!s.ok()) { + NoteError(s); + return first_error_; + } + + TF_RET_CHECK(response.output().handle() != 0); + return MakeUnique(client_->stub(), response.output()); +} + +ComputationDataHandle ComputationBuilder::Map( + tensorflow::gtl::ArraySlice operands, + const Computation& computation, + tensorflow::gtl::ArraySlice static_operands) { + if (!first_error_.ok() || !PrepareComputation().ok()) { + return ComputationDataHandle(); + } + + MapRequest request; + for (const ComputationDataHandle& operand : operands) { + *request.add_operands() = operand; + } + *request.mutable_to_apply() = computation.handle(); + for (const ComputationDataHandle& sop : static_operands) { + *request.add_static_operands() = sop; + } + OpRequest op_request; + *op_request.mutable_computation() = computation_.handle(); + *op_request.mutable_map_request() = request; + OpResponse response; + + VLOG(2) << "making Map request"; + Status s = client_->stub()->Op(&op_request, &response); + return ParseOpResponse(s, &response); +} + +ComputationDataHandle ComputationBuilder::RngNormal( + const ComputationDataHandle& mu, const ComputationDataHandle& sigma, + const Shape& shape) { + return RngOp(RandomDistribution::RNG_NORMAL, {mu, sigma}, shape); +} + +ComputationDataHandle ComputationBuilder::RngUniform( + const ComputationDataHandle& a, const ComputationDataHandle& b, + const Shape& shape) { + return RngOp(RandomDistribution::RNG_UNIFORM, {a, b}, shape); +} + +ComputationDataHandle ComputationBuilder::RngBernoulli( + const ComputationDataHandle& mean, const Shape& shape) { + return RngOp(RandomDistribution::RNG_BERNOULLI, {mean}, shape); +} + +ComputationDataHandle ComputationBuilder::While( + const Computation& condition, const Computation& body, + const ComputationDataHandle& init) { + if (!first_error_.ok() || !PrepareComputation().ok()) { + return ComputationDataHandle(); + } + + WhileRequest request; + *request.mutable_condition() = condition.handle(); + *request.mutable_body() = body.handle(); + *request.mutable_init() = init; + OpRequest op_request; + *op_request.mutable_computation() = computation_.handle(); + *op_request.mutable_while_request() = request; + OpResponse response; + + VLOG(2) << "making while request"; + Status s = client_->stub()->Op(&op_request, &response); + return ParseOpResponse(s, &response); +} + +ComputationDataHandle ComputationBuilder::Reduce( + const ComputationDataHandle& operand, + const ComputationDataHandle& init_value, const Computation& computation, + tensorflow::gtl::ArraySlice dimensions_to_reduce) { + if (!first_error_.ok() || !PrepareComputation().ok()) { + return ComputationDataHandle(); + } + + ReduceRequest request; + *request.mutable_operand() = operand; + *request.mutable_init_value() = init_value; + for (int64 dimension : dimensions_to_reduce) { + request.add_dimensions(dimension); + } + *request.mutable_to_apply() = computation.handle(); + OpRequest op_request; + *op_request.mutable_computation() = 
computation_.handle(); + *op_request.mutable_reduce_request() = request; + OpResponse response; + + VLOG(2) << "making reduce request"; + Status s = client_->stub()->Op(&op_request, &response); + return ParseOpResponse(s, &response); +} + +ComputationDataHandle ComputationBuilder::ReduceWindow( + const ComputationDataHandle& operand, + const ComputationDataHandle& init_value, const Computation& computation, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, Padding padding) { + if (!first_error_.ok()) { + return ComputationDataHandle(); + } + + StatusOr> shape = GetShape(operand); + if (!shape.ok()) { + // Just early return with the existing error status. + first_error_ = shape.status(); + return ComputationDataHandle(); + } + + return ReduceWindowWithGeneralPadding( + operand, init_value, computation, window_dimensions, window_strides, + MakePadding(AsInt64Slice(shape.ValueOrDie()->dimensions()), + window_dimensions, window_strides, padding)); +} + +ComputationDataHandle ComputationBuilder::ReduceWindowWithGeneralPadding( + const ComputationDataHandle& operand, + const ComputationDataHandle& init_value, const Computation& computation, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding) { + if (!first_error_.ok() || !PrepareComputation().ok()) { + return ComputationDataHandle(); + } + + ReduceWindowRequest request; + *request.mutable_operand() = operand; + *request.mutable_to_apply() = computation.handle(); + *request.mutable_init_value() = init_value; + + if (!MakeWindow(window_dimensions, window_strides, padding, {}, {}, + request.mutable_window())) { + NoteError(InternalError("failed to make window")); + return ComputationDataHandle(); + } + OpRequest op_request; + *op_request.mutable_computation() = computation_.handle(); + *op_request.mutable_reduce_window_request() = request; + OpResponse response; + + VLOG(2) << "making reduce-window request"; + Status s = client_->stub()->Op(&op_request, &response); + return ParseOpResponse(s, &response); +} + +ComputationDataHandle ComputationBuilder::CrossReplicaSum( + const ComputationDataHandle& operand) { + if (!first_error_.ok() || !PrepareComputation().ok()) { + return ComputationDataHandle(); + } + + CrossReplicaSumRequest request; + *request.mutable_operand() = operand; + OpRequest op_request; + *op_request.mutable_cross_replica_sum_request() = request; + *op_request.mutable_computation() = computation_.handle(); + OpResponse response; + + VLOG(2) << "making cross-replica-sum request"; + Status s = client_->stub()->Op(&op_request, &response); + return ParseOpResponse(s, &response); +} + +ComputationDataHandle ComputationBuilder::SelectAndScatter( + const ComputationDataHandle& operand, const Computation& select, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, Padding padding, + const ComputationDataHandle& source, + const ComputationDataHandle& init_value, const Computation& scatter) { + if (!first_error_.ok()) { + return ComputationDataHandle(); + } + + StatusOr> shape = GetShape(operand); + if (!shape.ok()) { + // Just early return with the existing error status. 
+ first_error_ = shape.status(); + return ComputationDataHandle(); + } + return SelectAndScatterWithGeneralPadding( + operand, select, window_dimensions, window_strides, + MakePadding(AsInt64Slice(shape.ValueOrDie()->dimensions()), + window_dimensions, window_strides, padding), + source, init_value, scatter); +} + +ComputationDataHandle ComputationBuilder::SelectAndScatterWithGeneralPadding( + const ComputationDataHandle& operand, const Computation& select, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + const ComputationDataHandle& source, + const ComputationDataHandle& init_value, const Computation& scatter) { + if (!first_error_.ok() || !PrepareComputation().ok()) { + return ComputationDataHandle(); + } + + SelectAndScatterRequest request; + *request.mutable_operand() = operand; + *request.mutable_select() = select.handle(); + *request.mutable_source() = source; + *request.mutable_init_value() = init_value; + *request.mutable_scatter() = scatter.handle(); + + if (!MakeWindow(window_dimensions, window_strides, padding, {}, {}, + request.mutable_window())) { + NoteError(InternalError("failed to make window")); + return ComputationDataHandle(); + } + OpRequest op_request; + *op_request.mutable_computation() = computation_.handle(); + *op_request.mutable_select_and_scatter_request() = request; + OpResponse response; + + VLOG(2) << "making select-and-scatter request"; + Status s = client_->stub()->Op(&op_request, &response); + return ParseOpResponse(s, &response); +} + +void ComputationBuilder::Send(const ComputationDataHandle& operand, + const ChannelHandle& handle) { + if (!first_error_.ok() || !PrepareComputation().ok()) { + return; + } + + SendRequest request; + *request.mutable_operand() = operand; + *request.mutable_channel_handle() = handle; + OpRequest op_request; + *op_request.mutable_send_request() = request; + *op_request.mutable_computation() = computation_.handle(); + OpResponse response; + + VLOG(2) << "making send request"; + tensorflow::Status s = client_->stub()->Op(&op_request, &response); + VLOG(2) << "done with request"; + + if (!s.ok()) { + NoteError(s); + return; + } +} + +ComputationDataHandle ComputationBuilder::Recv(const Shape& shape, + const ChannelHandle& handle) { + if (!first_error_.ok() || !PrepareComputation().ok()) { + return ComputationDataHandle(); + } + + RecvRequest request; + *request.mutable_shape() = shape; + *request.mutable_channel_handle() = handle; + OpRequest op_request; + *op_request.mutable_recv_request() = request; + *op_request.mutable_computation() = computation_.handle(); + OpResponse response; + + VLOG(2) << "making recv request"; + tensorflow::Status s = client_->stub()->Op(&op_request, &response); + VLOG(2) << "done with request"; + + return ParseOpResponse(s, &response); +} + +Computation ComputationBuilder::BuildAndNoteError() { + DCHECK(parent_builder_ != nullptr); + auto build_status = Build(); + if (!build_status.ok()) { + parent_builder_->NoteError( + AddStatus(build_status.status(), + tensorflow::strings::StrCat("error from: ", name_))); + return Computation(); + } + return build_status.ConsumeValueOrDie(); +} + +StatusOr ComputationBuilder::Build() { + if (!first_error_.ok()) { + string backtrace; + first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace); + return AppendStatus(first_error_, backtrace); + } + + if (computation_.IsNull()) { + return FailedPrecondition("no computation was built"); + } + + return 
{std::move(computation_)}; +} + +/* static */ ConvolutionDimensionNumbers +ComputationBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) { + ConvolutionDimensionNumbers dimension_numbers; + dimension_numbers.set_batch_dimension(kConvBatchDimension); + dimension_numbers.set_feature_dimension(kConvFeatureDimension); + dimension_numbers.set_kernel_output_feature_dimension( + kConvKernelOutputDimension); + dimension_numbers.set_kernel_input_feature_dimension( + kConvKernelInputDimension); + for (int i = 0; i < num_spatial_dims; ++i) { + dimension_numbers.add_spatial_dimensions(i + 2); + dimension_numbers.add_kernel_spatial_dimensions(i + 2); + } + return dimension_numbers; +} + +/* static */ StatusOr +ComputationBuilder::CreateConvDimensionNumbers( + int64 batch, int64 feature, int64 first_spatial, int64 second_spatial, + int64 kernel_output_feature, int64 kernel_input_feature, + int64 kernel_first_spatial, int64 kernel_second_spatial) { + if (std::set({batch, feature, first_spatial, second_spatial}).size() != + 4) { + return FailedPrecondition( + "dimension numbers for the input are not unique: (%lld, %lld, %lld, " + "%lld)", + batch, feature, first_spatial, second_spatial); + } + if (std::set({kernel_output_feature, kernel_input_feature, + kernel_first_spatial, kernel_second_spatial}) + .size() != 4) { + return FailedPrecondition( + "dimension numbers for the weight are not unique: (%lld, %lld, %lld, " + "%lld)", + kernel_output_feature, kernel_input_feature, kernel_first_spatial, + kernel_second_spatial); + } + ConvolutionDimensionNumbers dimension_numbers; + dimension_numbers.set_batch_dimension(batch); + dimension_numbers.set_feature_dimension(feature); + dimension_numbers.add_spatial_dimensions(first_spatial); + dimension_numbers.add_spatial_dimensions(second_spatial); + dimension_numbers.set_kernel_output_feature_dimension(kernel_output_feature); + dimension_numbers.set_kernel_input_feature_dimension(kernel_input_feature); + dimension_numbers.add_kernel_spatial_dimensions(kernel_first_spatial); + dimension_numbers.add_kernel_spatial_dimensions(kernel_second_spatial); + return dimension_numbers; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h new file mode 100644 index 0000000000..a74257eae3 --- /dev/null +++ b/tensorflow/compiler/xla/client/computation_builder.h @@ -0,0 +1,783 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_BUILDER_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_BUILDER_H_ + +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/array3d.h" +#include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/client/client.h" +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/client/padding.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/bitmap.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/stacktrace.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// Wraps an XLA client with a convenient interface for building up +// computations. Any errors encountered in building up the computation are +// deferred from being handled until Build() is called. +// +// Thread-compatible. +class ComputationBuilder { + public: + // client: client in which to build the computation. + // computation_name: name to use for the built computation. + ComputationBuilder(Client* client, const string& computation_name); + + ~ComputationBuilder(); + + // Returns the client the builder was initialized with. + Client* client() { return client_; } + + // Returns the computation name. + const string& name() { return name_; } + + // Sets the builder to a mode where it will die immediately when an error is + // encountered, rather than producing it in a deferred fashion when Build() is + // called (which is the default). + void set_die_immediately_on_error(bool enabled) { + die_immediately_on_error_ = enabled; + } + + // Enqueues a "retrieve parameter value" instruction for a parameter that was + // passed to the computation. + ComputationDataHandle Parameter(int64 parameter_number, const Shape& shape, + const string& name); + + // Retrieves the (inferred) shape of the operand in the computation. + StatusOr> GetShape( + const ComputationDataHandle& operand); + + // Checks that the operand has the given expected shape. Returns the operand + // if yes, fails with a CHECK error if no. + ComputationDataHandle CheckShape(const ComputationDataHandle& operand, + const Shape& expected_shape); + + // Checks that the lhs and rhs results have the same shape. + void CheckSameShape(const ComputationDataHandle& lhs, + const ComputationDataHandle& rhs); + + // Enqueues a constant with the value of the given literal onto the + // computation. + ComputationDataHandle ConstantLiteral(const Literal& literal); + + // Enqueues a constant onto the computation. Methods are templated on the + // native host type (NativeT) which corresponds to a specific XLA + // PrimitiveType as given in the following table: + // + // Native Type PrimitiveType + // ----------------------------- + // bool PRED + // int32 S32 + // int64 S64 + // uint32 U32 + // uint64 U64 + // float F32 + // double F64 + // + // Note: not all primitive types defined in xla_data.proto have a + // corresponding native type yet. 
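+  // Usage sketch (illustrative only; `b` is a hypothetical ComputationBuilder
+  // and the native types follow the table above):
+  //
+  //   ComputationDataHandle half = b.ConstantR0<float>(0.5f);            // F32 scalar
+  //   ComputationDataHandle ints = b.ConstantR1<int32>({1, 2, 3});       // S32 vector
+  //   ComputationDataHandle mat =
+  //       b.ConstantR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}});             // F32 matrix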
+ template + ComputationDataHandle ConstantR0(NativeT value); + template + ComputationDataHandle ConstantR1(tensorflow::gtl::ArraySlice values); + ComputationDataHandle ConstantR1(const tensorflow::core::Bitmap& values); + template + ComputationDataHandle ConstantR2( + std::initializer_list> values); + template + ComputationDataHandle ConstantR2FromArray2D(const Array2D& values); + template + ComputationDataHandle ConstantR3FromArray3D(const Array3D& values); + template + ComputationDataHandle ConstantR4FromArray4D(const Array4D& values); + + // Enqueues a rank one constant (vector) onto the computation. The vector has + // size 'length' and every element has the value 'value'. + template + ComputationDataHandle ConstantR1(int64 length, NativeT value); + + // Adds dimensions to an array by duplicating the data in the array. + // + // The new dimensions are inserted on the left, i.e. if + // broadcast_sizes has values {a0, ..., aN} and the operand shape + // has dimensions {b0, ..., bM} then the shape of the output has + // dimensions {a0, ..., aN, b0, ..., bM}. + // + // The new dimensions index into copies of the operand, i.e. + // + // output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM] + ComputationDataHandle Broadcast( + const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice broadcast_sizes); + + // Enqueues a pad operation onto the computation that pads the given value on + // the edges as well as between the elements of the input. padding_config + // specifies the padding amount for each dimension. + ComputationDataHandle Pad(const ComputationDataHandle& operand, + const ComputationDataHandle& padding_value, + const PaddingConfig& padding_config); + + // Enqueues an operation onto the computation that flattens the operand based + // on the dimension order (major/slowest-varying to minor/fastest-varying) + // given, followed by reshaping it into the shape with the given dimension + // sizes (also major to minor). Conceptually, this is a limited form of + // "shape casting". + ComputationDataHandle Reshape(const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice new_sizes); + + // Enqueues an operation onto the computation that collapses the operand, from + // minor to major order, then reshapes it into the shape with the given + // dimension sizes, also from major to minor. Conceptually, this is a limited + // form of "shape casting". + ComputationDataHandle Reshape(const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice new_sizes); + + // Wrapper for Reshape. + // Enqueues an operation to collapse the provided dimensions; e.g. an + // operand with dimensions {x=256, y=2, z=2, p=32} can be collapsed to + // {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must + // be a consecutive, in-order subsequence of the operand dimensions. + // + // This could potentially cause data to be moved -- it provides a more + // structured form of reshaping than an arbitrary Reshape operation. + ComputationDataHandle Collapse(const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice dimensions); + + // Enqueues a slice operation onto the computation that slices the operand + // from the start indices to the limit indices; e.g. + // + // x + // [ 0 1 2 3 ] + // y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ] + // [ 8 9 a b ] + // + // Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D + // range notation. 
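+  // For example (sketch; `x` is a hypothetical handle to the 3x4 array drawn
+  // above):
+  //
+  //   auto sub = builder.Slice(x, /*start_indices=*/{1, 1},
+  //                            /*limit_indices=*/{2, 3});
+  //   // `sub` now denotes the 1x2 array [ 5 6 ].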
+ ComputationDataHandle Slice(const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices); + + // Enqueues a slice operation onto the computation that slices the 'operand' + // from dynamic start indices which are passed in 'start_indices'. + // The size of the slice in each dimension is passed in 'slice_sizes', + // which specify the end point of exclusive slice intervals in each + // dimension [start, start + size). + // The shape of 'start_indices' must be rank == 1, with dimension size + // equal to the rank of the 'operand'. + // Slice index calculations are computed modulo input dimension sizes to + // prevent dynamic start indices from generating out-of-bound array accesses. + ComputationDataHandle DynamicSlice( + const ComputationDataHandle& operand, + const ComputationDataHandle& start_indices, + tensorflow::gtl::ArraySlice slice_sizes); + + // Enqueues a dynamic update slice operation onto the computation, which + // updates a slice of 'operand' with 'update' at dynamic 'start_indices'. + // The shape of 'update' determines the shape of the slice of 'operand' + // which is updated. + // The indices specified in 'start_indices' specify the offset of the slice + // of 'operand' which is updated. + // + // update = {10, 11} // calculated at runtime. + // [1 2 3] start = {1, 1} // calculated at runtime. [1 2 3 ] + // [4 5 6] => DynamicUpdateslice(data, update, start) => [4 10 11] + // [7 8 9] [7 8 9 ] + // + // The shape of 'start_indices' must be rank == 1, with dimension size + // equal to the rank of the 'operand'. + // Slice index calculations are computed modulo update dimension sizes to + // prevent dynamic start indices from generating out-of-bound array accesses. + ComputationDataHandle DynamicUpdateSlice( + const ComputationDataHandle& operand, const ComputationDataHandle& update, + const ComputationDataHandle& start_indices); + + // Enqueues a concatenate instruction onto the computation. 'operands' must + // have >= 1 entry. + ComputationDataHandle ConcatInDim( + tensorflow::gtl::ArraySlice operands, + int64 dimension); + + // Enqueue a tracing operation onto the computation; the computation will emit + // a logging message with the operand. + void Trace(const string& tag, const ComputationDataHandle& operand); + + // Enqueues a conditional-move-like select operation onto the computation; + // predicated on pred, selects between on_true and on_false. + ComputationDataHandle Select(const ComputationDataHandle& pred, + const ComputationDataHandle& on_true, + const ComputationDataHandle& on_false); + + // Enqueues a tuple-creation instruction onto the computation. + ComputationDataHandle Tuple( + tensorflow::gtl::ArraySlice elements); + + // Enqueues a tuple-element-get instruction onto the computation. + ComputationDataHandle GetTupleElement(const ComputationDataHandle& tuple_data, + int64 index); + + // Enqueues an equal-to comparison instruction onto the computation. + ComputationDataHandle Eq( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a not-equal comparison instruction onto the computation. + ComputationDataHandle Ne( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a greater-or-equal comparison instruction onto the computation. 
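+  // The comparison ops take the same optional broadcast_dimensions as the
+  // element-wise arithmetic ops below. Sketch (hypothetical handles):
+  // comparing each row of a [2, 3] matrix `m` against a length-3 vector `v`
+  // by mapping `v` onto dimension 1 of `m`:
+  //
+  //   auto mask = builder.Eq(m, v, /*broadcast_dimensions=*/{1});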
+ ComputationDataHandle Ge( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a greater-than comparison instruction onto the computation. + ComputationDataHandle Gt( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a less-than comparison instruction onto the computation. + ComputationDataHandle Lt( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a less-or-equal comparison instruction onto the computation. + ComputationDataHandle Le( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a dot instruction onto the computation. + ComputationDataHandle Dot(const ComputationDataHandle& lhs, + const ComputationDataHandle& rhs); + + // Default dimension numbers used for a 2D convolution. + static constexpr int64 kConvBatchDimension = 0; + static constexpr int64 kConvFeatureDimension = 1; + static constexpr int64 kConvFirstSpatialDimension = 2; + static constexpr int64 kConvSecondSpatialDimension = 3; + static constexpr int64 kConvKernelOutputDimension = 0; + static constexpr int64 kConvKernelInputDimension = 1; + static constexpr int64 kConvKernelFirstSpatialDimension = 2; + static constexpr int64 kConvKernelSecondSpatialDimension = 3; + + // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for + // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for + // the kernel operand + // {output_feature, input_feature, height, width} = {0, 1, 2, 3}. + static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers( + int num_spatial_dims = 2); + + // Creates a ConvolutionDimensionNumbers with the given arguments. Returns an + // error if either the input or the weight dimension numbers have conflicts. + static StatusOr CreateConvDimensionNumbers( + int64 batch, int64 feature, int64 first_spatial, int64 second_spatial, + int64 kernel_output_feature, int64 kernel_input_feature, + int64 kernel_first_spatial, int64 kernel_second_spatial); + + // Enqueues a convolution instruction onto the computation, which uses the + // default convolution dimension numbers. + ComputationDataHandle Conv(const ComputationDataHandle& lhs, + const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice window_strides, + Padding padding); + + // Enqueues a convolution instruction onto the computation, with the caller + // provided padding configuration in the format returned by MakePadding(). + ComputationDataHandle ConvWithGeneralPadding( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding); + + // Enqueues a convolution instruction onto the computation, with the caller + // provided dimension numbers configuration. + ComputationDataHandle ConvWithGeneralDimensions( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice window_strides, Padding padding, + const ConvolutionDimensionNumbers& dimension_numbers); + + // Enqueues a convolution instruction onto the computation, with the caller + // provided padding configuration as well as the dimension numbers. 
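+  // Sketch (hypothetical handles and shapes): a 2D convolution over an NCHW
+  // input `x` with an OIHW kernel `w`, using the default dimension numbers
+  // declared above and "same" padding:
+  //
+  //   auto y = builder.Conv(x, w, /*window_strides=*/{1, 1}, Padding::kSame);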
+ ComputationDataHandle ConvGeneral( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + const ConvolutionDimensionNumbers& dimension_numbers); + + // Enqueues a convolution instruction onto the computation, with the caller + // provided padding configuration, dilation factors and dimension numbers. + ComputationDataHandle ConvGeneralDilated( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + tensorflow::gtl::ArraySlice lhs_dilation, + tensorflow::gtl::ArraySlice rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers); + + // Enqueues an infeed instruction onto the computation, which reads data of + // the given shape from the infeed buffer of the device. + ComputationDataHandle Infeed(const Shape& shape); + + // Enqueues a call instruction onto the computation. + ComputationDataHandle Call( + const Computation& computation, + tensorflow::gtl::ArraySlice operands); + + // Enqueues a custom call instruction onto the computation. + // During code generation, a call instruction is emitted which targets a + // symbol with the name |call_target_name|. The |operands| are passed to the + // call instruction. |shape| is the resultant shape. + ComputationDataHandle CustomCall( + const string& call_target_name, + tensorflow::gtl::ArraySlice operands, + const Shape& shape); + + // The following methods enqueue element-wise binary arithmetic operations + // onto the computation. The shapes of the operands have to match unless one + // of the operands is a scalar, or an explicit broadcast dimension is given + // (see g3doc for more details). + + // Enqueues an add instruction onto the computation. + ComputationDataHandle Add( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a subtract instruction onto the computation. + ComputationDataHandle Sub( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a multiply instruction onto the computation. + ComputationDataHandle Mul( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a divide instruction onto the computation. + ComputationDataHandle Div( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a remainder instruction onto the computation. + ComputationDataHandle Rem( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a max instruction onto the computation. + ComputationDataHandle Max( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a min instruction onto the computation. 
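+  // Sketch of the broadcast form shared by these binary ops (hypothetical
+  // handles): adding a length-N `bias` vector to every row of an [M, N]
+  // matrix `m` by mapping the vector onto dimension 1 of the matrix:
+  //
+  //   auto biased = builder.Add(m, bias, /*broadcast_dimensions=*/{1});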
+ ComputationDataHandle Min( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Element-wise logical operators + ComputationDataHandle LogicalAnd( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + ComputationDataHandle LogicalOr( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + ComputationDataHandle LogicalNot(const ComputationDataHandle& lhs); + + // Reduces an array among the provided dimensions, given "computation" as a + // reduction operator. + ComputationDataHandle Reduce( + const ComputationDataHandle& operand, + const ComputationDataHandle& init_value, const Computation& computation, + tensorflow::gtl::ArraySlice dimensions_to_reduce); + + // Enqueues a windowed reduce instruction onto the computation. + ComputationDataHandle ReduceWindow( + const ComputationDataHandle& operand, + const ComputationDataHandle& init_value, const Computation& computation, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, Padding padding); + + // As ReduceWindow(), but the padding is given in the format + // returned by MakePadding(). + ComputationDataHandle ReduceWindowWithGeneralPadding( + const ComputationDataHandle& operand, + const ComputationDataHandle& init_value, const Computation& computation, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding); + + // Returns the sum of the operand value across all replicas. All replicas + // supply one input to the sum and all replicas receive the resulting sum. + ComputationDataHandle CrossReplicaSum(const ComputationDataHandle& operand); + + // Enqueues an operation that scatters the `source` array to the selected + // indices of each window. + ComputationDataHandle SelectAndScatter( + const ComputationDataHandle& operand, const Computation& select, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, Padding padding, + const ComputationDataHandle& source, + const ComputationDataHandle& init_value, const Computation& scatter); + + // As SelectAndScatter(), but the padding is given in the format + // returned by MakePadding(). + ComputationDataHandle SelectAndScatterWithGeneralPadding( + const ComputationDataHandle& operand, const Computation& select, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + const ComputationDataHandle& source, + const ComputationDataHandle& init_value, const Computation& scatter); + + // Enqueues an abs instruction onto the computation. + ComputationDataHandle Abs(const ComputationDataHandle& operand); + + // Enqueues an exp instruction onto the computation. + ComputationDataHandle Exp(const ComputationDataHandle& operand); + + // Enqueues a floor instruction onto the computation. + ComputationDataHandle Floor(const ComputationDataHandle& operand); + + // Enqueues a ceil instruction onto the computation. + ComputationDataHandle Ceil(const ComputationDataHandle& operand); + + // Enqueues an log instruction (natural logarithm) onto the computation. + ComputationDataHandle Log(const ComputationDataHandle& operand); + + // Enqueues a sign instruction onto the computation. 
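+  // Sketch for the Reduce method above (illustrative; names are hypothetical
+  // and the usual xla::ShapeUtil helpers are assumed): build a scalar `add`
+  // sub-computation, then sum a rank-2 operand over both of its dimensions:
+  //
+  //   auto add_builder = builder.CreateSubBuilder("add");
+  //   auto x = add_builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
+  //   auto y = add_builder->Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
+  //   add_builder->Add(x, y);
+  //   Computation add = add_builder->BuildAndNoteError();
+  //   auto sum = builder.Reduce(operand, builder.ConstantR0<float>(0.0f), add,
+  //                             /*dimensions_to_reduce=*/{0, 1});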
+ ComputationDataHandle Sign(const ComputationDataHandle& operand); + + // Enqueues a tanh instruction onto the computation. + ComputationDataHandle Tanh(const ComputationDataHandle& operand); + + // Enqueues a float32 sqrt instruction onto the computation. + // (float32 is specified as there is an implicit float32 0.5f constant + // exponent). + ComputationDataHandle SqrtF32(const ComputationDataHandle& operand); + + // Enqueues a float32 square instruction onto the computation. + // (float32 is specified as there is an implicit float32 2.0f constant + // exponent). + ComputationDataHandle SquareF32(const ComputationDataHandle& operand); + + // Enqueues a lhs^rhs computation onto the computation. + ComputationDataHandle Pow(const ComputationDataHandle& lhs, + const ComputationDataHandle& rhs); + + // Enqueues a convert instruction onto the computation that changes the + // element type of the operand array to primitive_type. + ComputationDataHandle ConvertElementType(const ComputationDataHandle& operand, + PrimitiveType new_element_type); + + // Enqueues a float32 reciprocal instruction onto the computation. + // (float32 is specified as there is an implicit float32 -1.0f constant + // exponent). + // + // TODO(leary) axe F32 suffix, can be determined by reflecting on the shape of + // the operand. + ComputationDataHandle ReciprocalF32(const ComputationDataHandle& operand); + + // Enqueues a negate instruction onto the computation. + ComputationDataHandle Neg(const ComputationDataHandle& operand); + + // Enqueues a transpose instruction onto the computation. + ComputationDataHandle Transpose( + const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice permutation); + + // Enqueues a reverse instruction onto the computation. The order of the + // elements in the given dimensions is reversed (i.e., the element at index i + // is moved to index dimension_size - 1 - i). + ComputationDataHandle Rev(const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice dimensions); + + // Enqueues a sort (as increasing order) instruction onto the computation. + ComputationDataHandle Sort(const ComputationDataHandle& operand); + + // Enqueues a clamp instruction onto the computation. + ComputationDataHandle Clamp(const ComputationDataHandle& min, + const ComputationDataHandle& operand, + const ComputationDataHandle& max); + + // Enqueues a map instruction onto the computation. + ComputationDataHandle Map( + tensorflow::gtl::ArraySlice operands, + const Computation& computation, + tensorflow::gtl::ArraySlice static_operands = {}); + + // Enqueues a N(mu, sigma) random number generation instruction onto the + // computation. + ComputationDataHandle RngNormal(const ComputationDataHandle& mu, + const ComputationDataHandle& sigma, + const Shape& shape); + + // Enqueues a U(a, b) random number generation instruction onto the + // computation. + ComputationDataHandle RngUniform(const ComputationDataHandle& a, + const ComputationDataHandle& b, + const Shape& shape); + + // Enqueues a B(1, p) random number generation instruction onto the + // computation. + ComputationDataHandle RngBernoulli(const ComputationDataHandle& mean, + const Shape& shape); + + // Enqueues a while node onto the computation. + ComputationDataHandle While(const Computation& condition, + const Computation& body, + const ComputationDataHandle& init); + + // Enqueues a Send node onto the computation, to send the given operand to + // a Recv instruction that shares the same channel handle. 
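+  // Sketch (illustrative; the channel handle is assumed to be obtained from
+  // the Client, and the builder/handle names are hypothetical): a producer
+  // computation sends a value that a consumer computation receives over the
+  // same channel:
+  //
+  //   producer_builder.Send(value, channel);
+  //   auto received =
+  //       consumer_builder.Recv(ShapeUtil::MakeShape(F32, {16}), channel);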
+  void Send(const ComputationDataHandle& operand, const ChannelHandle& handle);
+
+  // Enqueues a Recv node onto the computation. The data comes from a Send
+  // instruction that shares the same channel handle and its shape must
+  // be the same as the given shape.
+  ComputationDataHandle Recv(const Shape& shape, const ChannelHandle& handle);
+
+  // Returns true if 'operand' is a compile-time constant. A compile-time
+  // constant does not depend on parameters, or on stateful operators such
+  // as `RngNormal` or `Infeed`. Unlike `ComputeConstant`, `IsConstant` tests
+  // whether a computation is a compile-time constant without evaluating the
+  // computation.
+  StatusOr<bool> IsConstant(const ComputationDataHandle& operand);
+
+  // Computes the value of a constant indicated by a
+  // ComputationDataHandle.
+  //
+  // The handle must be from the computation currently being built -
+  // i.e., returned from this builder with no intervening call to
+  // Build(). This happens to currently work regardless of that, but
+  // that may stop working at any time.
+  //
+  // The handle must represent a constant value, which in this case
+  // means that it must not statically depend on a parameter to the
+  // computation that is being built.
+  //
+  // `IsConstant` can be used to test whether a computation is a compile-time
+  // constant without evaluating it. `ComputeConstant` only succeeds for
+  // computations where `IsConstant` returns true.
+  //
+  // This functionality can be useful when translating a computation
+  // into XLA where something that looked dynamic is required by
+  // XLA to be specified as a constant. E.g. the source
+  // computation (outside of XLA) may include a dynamic
+  // computation of the shape of something and ComputeConstant lets
+  // you determine what the value of that computation is in the case
+  // where the value can be determined at compile time.
+  //
+  // If output_layout is non-null, then the output of the computation
+  // will be stored using that layout.
+  StatusOr<std::unique_ptr<GlobalData>> ComputeConstant(
+      const ComputationDataHandle& handle,
+      const Layout* output_layout = nullptr);
+
+  // Returns a new ComputationBuilder whose resultant Computation is used only
+  // by this ComputationBuilder. The sub-ComputationBuilder has the same
+  // die_immediately_on_error behavior as the parent.
+  std::unique_ptr<ComputationBuilder> CreateSubBuilder(
+      const string& computation_name);
+
+  // Modifies the computation being built so that executions of it
+  // will return the value associated with operand, rather than the
+  // last expression enqueued on the ComputationBuilder. Any subsequent
+  // operations added to the ComputationBuilder will not have any effect unless
+  // SetReturnValue is called again.
+  Status SetReturnValue(const ComputationDataHandle& operand);
+
+  // Builds the computation with the requested operations, or returns a non-ok
+  // status.
+  StatusOr<Computation> Build();
+
+  // Builds the computation with the requested operations, or notes an error in
+  // the parent ComputationBuilder and returns an empty computation if building
+  // failed. This function is intended for cases where the returned Computation
+  // is only used by the parent ComputationBuilder: any further operation on
+  // such a Computation simply produces an error if an error occurred while
+  // building it. If the built Computation is to be used by a ComputationBuilder
+  // other than the parent, use Build() instead.
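+  //
+  // For example, the scalar-computation helpers in client/lib/arithmetic.cc
+  // use this pattern (sketch; `builder` is the parent ComputationBuilder and
+  // `scalar` is a scalar Shape such as ShapeUtil::MakeShape(F32, {})):
+  //
+  //   auto b = builder->CreateSubBuilder("add");
+  //   b->Add(b->Parameter(0, scalar, "lhs"), b->Parameter(1, scalar, "rhs"));
+  //   Computation add = b->BuildAndNoteError();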
+ Computation BuildAndNoteError(); + + private: + using PopulateLiteral = std::function; + + // Limited checking of convolution parameters. Returns false on + // error. + bool VerifyConvolution(const Shape& lhs_shape, const Shape& rhs_shape, + const ConvolutionDimensionNumbers& dimension_numbers); + + // The parent ComputationBuilder of a sub-ComputationBuilder. The + // parent_builder_ will be the nullptr if not a sub-ComputationBuilder. + ComputationBuilder* parent_builder_{nullptr}; + + // Helper function for creating a Window proto from user-supplied + // data. Returns true if the user-supplied data was valid. + bool MakeWindow(tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + tensorflow::gtl::ArraySlice lhs_dilation, + tensorflow::gtl::ArraySlice rhs_dilation, + Window* window); + + // Internal helper method that makes a request for a constant operation -- the + // provided function is used to populate the literal before sending the + // request. + ComputationDataHandle ConstantOp(const PopulateLiteral& populate); + + // Internal helper method that does the building for an arbitrary unary op. + ComputationDataHandle UnaryOp(UnaryOperation binop, + const ComputationDataHandle& operand); + + // Internal helper method that does the building for an arbitrary binary op. + // broadcast_dimensions specifies which dimensions to use for broadcasting + // when the operation is between tensors of different ranks. + ComputationDataHandle BinaryOp( + BinaryOperation binop, const ComputationDataHandle& lhs, + const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + + // Internal helper method that does the building for an arbitrary ternary op. + ComputationDataHandle TernaryOp(TernaryOperation triop, + const ComputationDataHandle& lhs, + const ComputationDataHandle& rhs, + const ComputationDataHandle& ehs); + + // Internal helper method that does the building for a random number generator + // of a given distribution with an explicitly specified shape. + ComputationDataHandle RngOp( + RandomDistribution distribution, + tensorflow::gtl::ArraySlice parameters, + const Shape& shape); + + // Populates computation_ with a valid object or returns a failing status. + // This is used before any given operation is enqueued. + Status PrepareComputation(); + + // Helper function for parsing a method response and either returning the + // output computation data handle (on success) or a vacuous computation data + // handle (on failure). + ComputationDataHandle ParseOpResponse(const Status& status, + OpResponse* response); + + // Notes that the error occurred by: + // * storing it internally and capturing a backtrace if it's the first error + // (this deferred value will be produced on the call to Build()) + // * dying if die_immediately_on_error_ is true + void NoteError(const Status& error); + + string name_; // Name to use for the built computation. + + // The first error encountered while building the computation. + // This is OK until the first error is encountered. + Status first_error_; + + // The saved stack trace from the point at which the first error occurred. + tensorflow::SavedStackTrace first_error_backtrace_; + + // The computation that operations are enqueued onto. + Computation computation_; + + // The client that the computation is created in. Not owned. + Client* client_; + + // Mode bit that indicates whether to die when a first error is encountered. 
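+  // When this bit is false, the first error is instead stored in first_error_
+  // (with its backtrace in first_error_backtrace_) and is surfaced by Build();
+  // a rough usage sketch, with `b` a ComputationBuilder and `x` a handle whose
+  // enqueue already failed:
+  //
+  //   auto y = b.Neg(x);                        // enqueued after the error
+  //   StatusOr<Computation> built = b.Build();  // returns the first error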
+ bool die_immediately_on_error_{false}; + + TF_DISALLOW_COPY_AND_ASSIGN(ComputationBuilder); +}; + +template +ComputationDataHandle ComputationBuilder::ConstantR0(NativeT value) { + return ConstantOp( + [value](Literal* literal) { LiteralUtil::PopulateR0(value, literal); }); +} + +template +ComputationDataHandle ComputationBuilder::ConstantR1( + tensorflow::gtl::ArraySlice values) { + return ConstantOp([&values](Literal* literal) { + LiteralUtil::PopulateR1(values, literal); + }); +} + +template +ComputationDataHandle ComputationBuilder::ConstantR1(int64 length, + NativeT value) { + return ConstantOp([length, value](Literal* literal) { + LiteralUtil::PopulateWithValue(value, {length}, literal); + }); +} + +inline ComputationDataHandle ComputationBuilder::ConstantR1( + const tensorflow::core::Bitmap& values) { + return ConstantOp([&values](Literal* literal) { + LiteralUtil::PopulateR1(values, literal); + }); +} + +template +ComputationDataHandle ComputationBuilder::ConstantR2( + std::initializer_list> values) { + return ConstantOp([&values](Literal* literal) { + LiteralUtil::PopulateR2(values, literal); + }); +} + +template +ComputationDataHandle ComputationBuilder::ConstantR2FromArray2D( + const Array2D& values) { + return ConstantOp([&values](Literal* literal) { + LiteralUtil::PopulateR2FromArray2D(values, literal); + }); +} + +template +ComputationDataHandle ComputationBuilder::ConstantR3FromArray3D( + const Array3D& values) { + return ConstantOp([&values](Literal* literal) { + LiteralUtil::PopulateR3FromArray3D(values, literal); + }); +} + +template +ComputationDataHandle ComputationBuilder::ConstantR4FromArray4D( + const Array4D& values) { + return ConstantOp([&values](Literal* literal) { + LiteralUtil::PopulateR4FromArray4D(values, literal); + }); +} + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_BUILDER_H_ diff --git a/tensorflow/compiler/xla/client/global_data.cc b/tensorflow/compiler/xla/client/global_data.cc new file mode 100644 index 0000000000..be706f7d23 --- /dev/null +++ b/tensorflow/compiler/xla/client/global_data.cc @@ -0,0 +1,42 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/client/global_data.h" + +#include + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +GlobalData::GlobalData(ServiceInterface* parent, GlobalDataHandle handle) + : handle_(handle), parent_(parent) {} + +GlobalData::~GlobalData() { + UnregisterRequest request; + *request.mutable_data() = handle_; + UnregisterResponse response; + VLOG(1) << "requesting to unregister " << handle_.ShortDebugString(); + tensorflow::Status s = parent_->Unregister(&request, &response); + VLOG(1) << "done with request"; + + if (!s.ok()) { + LOG(WARNING) << "failed to unregister " << handle_.ShortDebugString() + << "; continuing anyway..."; + } +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/global_data.h b/tensorflow/compiler/xla/client/global_data.h new file mode 100644 index 0000000000..eb11d91034 --- /dev/null +++ b/tensorflow/compiler/xla/client/global_data.h @@ -0,0 +1,46 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_GLOBAL_DATA_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_GLOBAL_DATA_H_ + +#include "tensorflow/compiler/xla/service_interface.h" +#include "tensorflow/compiler/xla/xla.pb.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/macros.h" + +namespace xla { + +// Wraps a GlobalDataHandle with a lifetime. +class GlobalData { + public: + // Gives ownership of the global data handle to this object. + GlobalData(ServiceInterface* parent, GlobalDataHandle handle); + + // Unregisters the wrapped handle. + ~GlobalData(); + + const GlobalDataHandle& handle() const { return handle_; } + + private: + GlobalDataHandle handle_; // Handle being wrapped. + ServiceInterface* parent_; // Service used to unregister handle_. + + TF_DISALLOW_COPY_AND_ASSIGN(GlobalData); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_GLOBAL_DATA_H_ diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD new file mode 100644 index 0000000000..e185beaedd --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -0,0 +1,60 @@ +# Common computation builders for XLA. + +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow/compiler/xla/client:friends"]) + +# Filegroup used to collect source files for dependency checking. 
+filegroup( + name = "c_srcs", + data = glob([ + "**/*.cc", + "**/*.h", + ]), +) + +cc_library( + name = "arithmetic", + srcs = ["arithmetic.cc"], + hdrs = ["arithmetic.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + ], +) + +cc_library( + name = "testing", + srcs = ["testing.cc"], + hdrs = ["testing.h"], + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/core:lib", + ], +) + +# ----------------------------------------------------------------------------- + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc new file mode 100644 index 0000000000..31efd8ee64 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc @@ -0,0 +1,67 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" + +#include + +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +Computation CreateScalarAddComputation(PrimitiveType type, + ComputationBuilder* builder) { + const Shape scalar = ShapeUtil::MakeShape(type, {}); + auto b = builder->CreateSubBuilder("add_" + PrimitiveType_Name(type)); + auto lhs = b->Parameter(0, scalar, "lhs"); + auto rhs = b->Parameter(1, scalar, "rhs"); + b->Add(lhs, rhs); + return b->BuildAndNoteError(); +} + +Computation CreateScalarGeComputation(PrimitiveType type, + ComputationBuilder* builder) { + const Shape scalar = ShapeUtil::MakeShape(type, {}); + auto b = builder->CreateSubBuilder("ge_" + PrimitiveType_Name(type)); + auto lhs = b->Parameter(0, scalar, "lhs"); + auto rhs = b->Parameter(1, scalar, "rhs"); + b->Ge(lhs, rhs); + return b->BuildAndNoteError(); +} + +Computation CreateScalarMaxComputation(PrimitiveType type, + ComputationBuilder* builder) { + const Shape scalar = ShapeUtil::MakeShape(type, {}); + auto b = builder->CreateSubBuilder("max_" + PrimitiveType_Name(type)); + auto lhs = b->Parameter(0, scalar, "lhs"); + auto rhs = b->Parameter(1, scalar, "rhs"); + b->Max(lhs, rhs); + return b->BuildAndNoteError(); +} + +Computation CreateScalarMinComputation(PrimitiveType type, + ComputationBuilder* builder) { + const Shape scalar = ShapeUtil::MakeShape(type, {}); + auto b = builder->CreateSubBuilder("min_" + PrimitiveType_Name(type)); + auto lhs = b->Parameter(0, scalar, "lhs"); + auto rhs = b->Parameter(1, scalar, "rhs"); + b->Min(lhs, rhs); + return b->BuildAndNoteError(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.h b/tensorflow/compiler/xla/client/lib/arithmetic.h new file mode 100644 index 0000000000..57fe7d7462 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/arithmetic.h @@ -0,0 +1,45 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_ARITHMETIC_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_ARITHMETIC_H_ + +#include + +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +// Creates a scalar add computation and returns it. +Computation CreateScalarAddComputation(PrimitiveType type, + ComputationBuilder* builder); + +// Creates a scalar ge computation and returns it. +Computation CreateScalarGeComputation(PrimitiveType type, + ComputationBuilder* builder); + +// Creates a scalar max computation and returns it. 
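+//
+// These helper computations are typically passed to reduction-style builder
+// ops; an illustrative sketch (assumes a ComputationBuilder `b` and a rank-1
+// F32 handle `x` built with it):
+//
+//   Computation max = CreateScalarMaxComputation(F32, &b);
+//   auto max_of_x = b.Reduce(
+//       x, b.ConstantR0<float>(-std::numeric_limits<float>::infinity()),
+//       max, /*dimensions_to_reduce=*/{0});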
+Computation CreateScalarMaxComputation(PrimitiveType type, + ComputationBuilder* builder); + +// Creates a scalar min computation and returns it. +Computation CreateScalarMinComputation(PrimitiveType type, + ComputationBuilder* builder); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_ARITHMETIC_H_ diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc new file mode 100644 index 0000000000..004f3815d2 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/testing.cc @@ -0,0 +1,59 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/testing.h" + +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +std::unique_ptr MakeFakeDataOrDie(const Shape& shape, + Client* client) { + ComputationBuilder b( + client, + tensorflow::strings::StrCat("make_fake_", ShapeUtil::HumanString(shape))); + // TODO(b/26811613): Replace this when RNG is supported on all backends. + b.Broadcast(b.ConstantLiteral(LiteralUtil::One(shape.element_type())), + AsInt64Slice(shape.dimensions())); + Computation computation = b.Build().ConsumeValueOrDie(); + return client->Execute(computation, /*arguments=*/{}, &shape) + .ConsumeValueOrDie(); +} + +std::vector> MakeFakeArgumentsOrDie( + const Computation& computation, Client* client) { + auto program_shape = + client->GetComputationShape(computation).ConsumeValueOrDie(); + + // For every (unbound) parameter that the computation wants, we manufacture + // some arbitrary data so that we can invoke the computation. + std::vector> fake_arguments; + for (const Shape& parameter : program_shape->parameters()) { + fake_arguments.push_back(MakeFakeDataOrDie(parameter, client)); + } + + return fake_arguments; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/testing.h b/tensorflow/compiler/xla/client/lib/testing.h new file mode 100644 index 0000000000..7e640d1307 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/testing.h @@ -0,0 +1,43 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_TESTING_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_TESTING_H_ + +#include +#include + +#include "tensorflow/compiler/xla/client/client.h" +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +// Generates fake data of the given shape on the device or dies. The fake data +// is created by performing a computation on the device rather than transferring +// data from the host to the device. +std::unique_ptr MakeFakeDataOrDie(const Shape& shape, + Client* client); + +// Returns vector of GlobalData handles of fake data (created using +// MakeFakeDataOrDie) that are correctly shaped arguments for the given +// computation. +std::vector> MakeFakeArgumentsOrDie( + const Computation& computation, Client* client); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_TESTING_H_ diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc new file mode 100644 index 0000000000..148c033eaa --- /dev/null +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -0,0 +1,371 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/client/local_client.h" + +#include + +#include "external/llvm/include/llvm/ADT/Triple.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/backend.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/status_macros.h" + +namespace se = ::perftools::gputools; + +namespace xla { + +ExecutableBuildOptions& ExecutableBuildOptions::set_platform( + perftools::gputools::Platform* platform) { + platform_ = platform; + return *this; +} + +perftools::gputools::Platform* ExecutableBuildOptions::platform() const { + return platform_; +} + +ExecutableBuildOptions& ExecutableBuildOptions::set_device_ordinal( + int device_ordinal) { + device_ordinal_ = device_ordinal; + return *this; +} + +int ExecutableBuildOptions::device_ordinal() const { return device_ordinal_; } + +ExecutableBuildOptions& ExecutableBuildOptions::set_result_layout( + const Shape& shape_with_layout) { + result_layout_set_ = true; + result_layout_ = shape_with_layout; + return *this; +} + +const Shape* ExecutableBuildOptions::result_layout() const { + return result_layout_set_ ? &result_layout_ : nullptr; +} + +ExecutableBuildOptions& ExecutableBuildOptions::set_has_hybrid_result( + bool has_hybrid_result) { + has_hybrid_result_ = has_hybrid_result; + return *this; +} + +bool ExecutableBuildOptions::has_hybrid_result() const { + return has_hybrid_result_; +} + +namespace { + +// Convenience class which holds an acquired stream from the backend and +// automatically releases it when destructed. +class StreamManager { + public: + static StatusOr> AcquireStream( + Backend* backend, int device_ordinal) { + TF_ASSIGN_OR_RETURN( + se::StreamExecutor * executor, + backend->stream_executor(device_ordinal == -1 + ? backend->default_device_ordinal() + : device_ordinal)); + TF_ASSIGN_OR_RETURN(std::unique_ptr stream, + backend->AcquireStream(executor)); + return WrapUnique(new StreamManager(backend, std::move(stream))); + } + + ~StreamManager() { backend_->ReleaseStream(std::move(stream_)); } + + se::Stream* stream() const { return stream_.get(); } + + private: + StreamManager(Backend* backend, std::unique_ptr stream) + : backend_(backend), stream_(std::move(stream)) {} + + Backend* backend_; + std::unique_ptr stream_; +}; + +} // namespace + +LocalExecutable::LocalExecutable(std::unique_ptr executable, + Backend* backend, int device_ordinal, + const ExecutableBuildOptions& build_options) + : executable_(std::move(executable)), + backend_(backend), + build_device_ordinal_(device_ordinal), + build_options_(build_options) {} + +tensorflow::Status LocalExecutable::ValidateExecutionOptions( + const tensorflow::gtl::ArraySlice arguments, + const ExecutableRunOptions& options) { + const ComputationLayout& computation_layout = + executable_->module_config().entry_computation_layout(); + + // Check argument number, shapes, and layouts. 
+ if (arguments.size() != computation_layout.parameter_count()) { + return InvalidArgument( + "invalid number of arguments for computation: expected %d, got %zu", + computation_layout.parameter_count(), arguments.size()); + } + for (int i = 0; i < arguments.size(); ++i) { + if (!computation_layout.parameter_layout(i).MatchesLayoutInShape( + arguments[i]->shape())) { + return InvalidArgument( + "argument does not match shape or layout of computation parameter " + "%d: expected %s, got %s", + i, + ShapeUtil::HumanString(computation_layout.parameter_layout(i).shape()) + .c_str(), + ShapeUtil::HumanString(arguments[i]->shape()).c_str()); + } + } + + if (options.stream() != nullptr) { + if (!options.stream()->ok()) { + return InvalidArgument("stream is uninitialized or in an error state"); + } + + // Check stream matches service platform. + const se::Platform* stream_platform = + options.stream()->parent()->platform(); + if (stream_platform != backend_->platform()) { + return InvalidArgument( + "stream is for platform %s, but service targets platform %s", + stream_platform->Name().c_str(), + backend_->platform()->Name().c_str()); + } + + // Cannot specify device_ordinal with a stream. The stream determines these + // values. + if (options.device_ordinal() != -1) { + return InvalidArgument( + "cannot set both device ordinal and stream options in " + "ExecutableRunOptions; the stream determines the device ordinal"); + } + } + + // Verify that the device the executable was built for is equivalent to the + // device it will run on. + int run_device_ordinal = options.device_ordinal() == -1 + ? backend_->default_device_ordinal() + : options.device_ordinal(); + TF_ASSIGN_OR_RETURN( + bool devices_equivalent, + backend_->devices_equivalent(run_device_ordinal, build_device_ordinal_)); + if (!devices_equivalent) { + TF_ASSIGN_OR_RETURN(se::StreamExecutor * run_executor, + backend_->stream_executor(run_device_ordinal)); + TF_ASSIGN_OR_RETURN(se::StreamExecutor * build_executor, + backend_->stream_executor(build_device_ordinal_)); + return InvalidArgument( + "executable is built for device %s of type \"%s\"; cannot run it on " + "device %s of type \"%s\"", + backend_->device_name(build_device_ordinal_).c_str(), + build_executor->GetDeviceDescription().name().c_str(), + backend_->device_name(run_device_ordinal).c_str(), + run_executor->GetDeviceDescription().name().c_str()); + } + + return tensorflow::Status::OK(); +} + +StatusOr> LocalExecutable::Run( + const tensorflow::gtl::ArraySlice arguments, + const ExecutableRunOptions& options) { + TF_RETURN_IF_ERROR(ValidateExecutionOptions(arguments, options)); + + ExecutableRunOptions actual_options = options; + std::unique_ptr acquired_stream; + if (options.stream() == nullptr) { + TF_ASSIGN_OR_RETURN( + acquired_stream, + StreamManager::AcquireStream(backend_, options.device_ordinal())); + actual_options.set_stream(acquired_stream->stream()); + } + if (options.allocator() == nullptr) { + actual_options.set_allocator(backend_->memory_allocator()); + } + + if (executable_->dumping()) { + return ExecuteAndDump(&actual_options, arguments); + } + return executable_->ExecuteOnStream(&actual_options, arguments, + /*hlo_execution_profile=*/nullptr); +} + +tensorflow::Status LocalExecutable::Run( + const tensorflow::gtl::ArraySlice arguments, + const ExecutableRunOptions& options, ShapedBuffer* result) { + const ComputationLayout& computation_layout = + executable_->module_config().entry_computation_layout(); + TF_RETURN_IF_ERROR(ValidateExecutionOptions(arguments, 
options)); + + if (!computation_layout.result_layout().MatchesLayoutInShape( + result->shape())) { + return InvalidArgument( + "result buffer does not match shape or layout of computation result: " + "expected %s, got %s", + ShapeUtil::HumanString(computation_layout.result_layout().shape()) + .c_str(), + ShapeUtil::HumanString(result->shape()).c_str()); + } + + ExecutableRunOptions actual_options = options; + std::unique_ptr acquired_stream; + if (options.stream() == nullptr) { + TF_ASSIGN_OR_RETURN( + acquired_stream, + StreamManager::AcquireStream(backend_, options.device_ordinal())); + actual_options.set_stream(acquired_stream->stream()); + } + if (options.allocator() == nullptr) { + actual_options.set_allocator(backend_->memory_allocator()); + } + + if (executable_->dumping()) { + return Unimplemented("dumping execution not supported on this path"); + } + return executable_->ExecuteOnStream(&actual_options, arguments, result, + /*hlo_execution_profile=*/nullptr); +} + +StatusOr> LocalExecutable::ExecuteAndDump( + const ExecutableRunOptions* run_options, + const tensorflow::gtl::ArraySlice arguments) { + executable_->session_module()->set_execution_platform( + backend_->platform()->Name()); + TF_RETURN_IF_ERROR(RecordArguments(arguments, executable_->session_module())); + TF_ASSIGN_OR_RETURN( + std::unique_ptr result, + executable_->ExecuteOnStream(run_options, arguments, + /*hlo_execution_profile=*/nullptr)); + TF_RETURN_IF_ERROR(RecordResult(result.get(), executable_->session_module())); + TF_RETURN_IF_ERROR(executable_->DumpSessionModule()); + return std::move(result); +} + +tensorflow::Status LocalExecutable::RecordArguments( + const tensorflow::gtl::ArraySlice arguments, + SessionModule* session_module) { + session_module->clear_arguments(); + for (const ShapedBuffer* argument : arguments) { + TF_RETURN_IF_ERROR( + LiteralFromShapedBuffer(*argument, session_module->add_arguments())); + } + return tensorflow::Status::OK(); +} + +tensorflow::Status LocalExecutable::RecordResult( + const ShapedBuffer* result, SessionModule* session_module) { + session_module->clear_result(); + return LiteralFromShapedBuffer(*result, session_module->mutable_result()); +} + +tensorflow::Status LocalExecutable::LiteralFromShapedBuffer( + const ShapedBuffer& shaped_buffer, Literal* literal) { + TF_ASSIGN_OR_RETURN( + se::StreamExecutor * executor, + backend_->stream_executor(shaped_buffer.device_ordinal())); + return backend_->transfer_manager()->TransferLiteralFromDevice( + executor, shaped_buffer.buffer({}), shaped_buffer.shape(), + shaped_buffer.shape(), literal); +} + +StatusOr> LocalClient::AllocateBufferOnDevice( + const Shape& shape, int device_ordinal, bool allocate_space_for_deep_copy) { + TF_ASSIGN_OR_RETURN(GlobalDataHandle handle, + local_service_->AllocateBufferOnDevice( + shape, device_ordinal, allocate_space_for_deep_copy)); + return std::unique_ptr(new GlobalData(local_service_, handle)); +} + +tensorflow::Status LocalClient::ResolveArguments( + const tensorflow::gtl::ArraySlice arguments, + int device_ordinal, + std::vector* argument_ptrs) { + return local_service_->ResolveArguments(arguments, device_ordinal, + argument_ptrs); +} + +StatusOr> LocalClient::ExecuteLocally( + const Computation& computation, + const tensorflow::gtl::ArraySlice arguments, + const LocalExecuteOptions& options) { + return local_service_->ExecuteLocally(computation.handle(), arguments, + options); +} + +tensorflow::Status LocalClient::ExecuteLocally( + const Computation& computation, + const 
tensorflow::gtl::ArraySlice arguments, + const LocalExecuteOptions& options, ShapedBuffer* result) { + return local_service_->ExecuteLocally(computation.handle(), arguments, + options, result); +} + +StatusOr> LocalClient::CompileAheadOfTime( + const Computation& computation, + const tensorflow::gtl::ArraySlice argument_layouts, + const Shape& result_layout, const AotCompilationOptions& options) { + return local_service_->CompileAheadOfTime( + computation.handle(), argument_layouts, result_layout, options); +} + +int64 LocalClient::PointerSizeForTriple(tensorflow::StringPiece target_triple) { + llvm::Triple triple( + llvm::Triple::normalize(llvm_ir::AsStringRef(target_triple))); + if (triple.isArch64Bit()) { + return 8; + } else if (triple.isArch32Bit()) { + return 4; + } else { + CHECK(triple.isArch16Bit()); + return 2; + } +} + +se::Platform* LocalClient::platform() const { + return local_service_->backend().platform(); +} + +int LocalClient::device_count() const { + return local_service_->backend().device_count(); +} + +bool LocalClient::device_ordinal_supported(int device_ordinal) const { + return local_service_->backend().device_ordinal_supported(device_ordinal); +} + +int LocalClient::default_device_ordinal() const { + return local_service_->backend().default_device_ordinal(); +} + +StatusOr> LocalClient::Compile( + const Computation& computation, + const tensorflow::gtl::ArraySlice argument_layouts, + const ExecutableBuildOptions& options) { + int device_ordinal = options.device_ordinal() == -1 + ? default_device_ordinal() + : options.device_ordinal(); + TF_ASSIGN_OR_RETURN( + std::unique_ptr executable, + local_service_->CompileExecutable(computation.handle(), argument_layouts, + options.result_layout(), device_ordinal, + options.has_hybrid_result())); + return WrapUnique(new LocalExecutable(std::move(executable), + local_service_->mutable_backend(), + device_ordinal, options)); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h new file mode 100644 index 0000000000..1d6243a3b6 --- /dev/null +++ b/tensorflow/compiler/xla/client/local_client.h @@ -0,0 +1,263 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LOCAL_CLIENT_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LOCAL_CLIENT_H_ + +#include + +#include "tensorflow/compiler/xla/client/client.h" +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/executable_run_options.h" +#include "tensorflow/compiler/xla/service/compiler.h" +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/executable.h" +#include "tensorflow/compiler/xla/service/local_service.h" +#include "tensorflow/compiler/xla/service/shaped_buffer.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { + +// Class containing options for building an LocalExecutable with +// LocalClient::Compile. +class ExecutableBuildOptions { + public: + // If set, this is the platform to build the computation for. This must match + // the underlying platform of the service. A value of nullptr indicates the + // option has not been set. + // + // TODO(b/28616830): Support multiple platforms. + ExecutableBuildOptions& set_platform(perftools::gputools::Platform* platform); + perftools::gputools::Platform* platform() const; + + // If set, this is the device to build the computation for. Valid + // device_ordinal values are: 0 to # of devices - 1. These values are + // identical to the device ordinal values used by StreamExecutor. The built + // executable will be executable on any device equivalent to the specified + // device as determined by Backend::devices_equivalent(). A value of -1 + // indicates this option has not been set. + ExecutableBuildOptions& set_device_ordinal(int device_ordinal); + int device_ordinal() const; + + // If set, this specifies the layout of the result of the computation. If not + // set, the service will chose the layout of the result. A Shape is used to + // store the layout to accomodate tuple result shapes. A value of nullptr + // indicates the option has not been set. + ExecutableBuildOptions& set_result_layout(const Shape& shape_with_layout); + const Shape* result_layout() const; + + // If set, the executable will be built to output a hybrid + // ShapedBuffer with top-level tuple pointers in host memory and + // result buffers in device memory. + ExecutableBuildOptions& set_has_hybrid_result(bool has_hybrid_result); + bool has_hybrid_result() const; + + private: + perftools::gputools::Platform* platform_ = nullptr; + int device_ordinal_ = -1; + Shape result_layout_; + bool result_layout_set_ = false; + bool has_hybrid_result_ = true; +}; + +class LocalExecutable { + public: + // Run the compiled computation with the given arguments and options and + // return the result. + StatusOr> Run( + const tensorflow::gtl::ArraySlice arguments, + const ExecutableRunOptions& options); + + // Overload which places the computation result in the given preallocated + // buffer. + tensorflow::Status Run( + const tensorflow::gtl::ArraySlice arguments, + const ExecutableRunOptions& options, ShapedBuffer* result); + + // Return the layout (contained in a shape) of the result produced by the + // computation. 
+ const Shape& result_layout() const { + return executable_->module_config() + .entry_computation_layout() + .result_layout() + .shape(); + } + + // Return the options used to build the executable. + const ExecutableBuildOptions& build_options() const { return build_options_; } + + // Return the built executable. + Executable* executable() const { return executable_.get(); } + + private: + // Only a local client can construct these objects. + friend class LocalClient; + + // Constructor invoked by LocalClient. + LocalExecutable(std::unique_ptr executable, Backend* backend, + int device_ordinal, + const ExecutableBuildOptions& build_options); + + // Validates that the given arguments and options satisfy various constraints + // of the computation. + tensorflow::Status ValidateExecutionOptions( + const tensorflow::gtl::ArraySlice arguments, + const ExecutableRunOptions& options); + + // Records the computation in a SessionModule proto with the arguments used to + // invoke it, and the result. Enabled by flag: --tla_dump_executions_to. + StatusOr> ExecuteAndDump( + const ExecutableRunOptions* run_options, + const tensorflow::gtl::ArraySlice arguments); + + // Records the arguments used to invoke the computation in a SessionModule + // proto. + tensorflow::Status RecordArguments( + const tensorflow::gtl::ArraySlice arguments, + SessionModule* session_module); + + // Records the result of the computation in a SessionModule proto. + tensorflow::Status RecordResult(const ShapedBuffer* result, + SessionModule* session_module); + + // Copies the contents of a ShapedBuffer into a Literal proto. + tensorflow::Status LiteralFromShapedBuffer(const ShapedBuffer& shaped_buffer, + Literal* literal); + + // Compiled computation. + std::unique_ptr executable_; + + // Execution backend. + Backend* backend_; + + // The ordinal of the device which this executable was compiled for. The + // executable can run on all equivalent devices (as determined by + // Backend::devices_equivalent). + int build_device_ordinal_; + + // Options used to build the executable. + const ExecutableBuildOptions& build_options_; +}; + +// An XLA service client object for use when the client and service run in +// the same process. +class LocalClient : public Client { + public: + explicit LocalClient(LocalService* service) + : Client(service), local_service_(service) {} + + LocalClient(const LocalClient&) = delete; + void operator=(const LocalClient&) = delete; + + // For an array of arguments held on the local service, validate + // that each is placed on the specified device_ordinal, and return + // the DeviceMemoryBase corresponding to each argument. + tensorflow::Status ResolveArguments( + const tensorflow::gtl::ArraySlice arguments, + int device_ordinal, + std::vector* argument_ptrs); + + // Return a handle to a buffer large enough to hold shape, allocated + // on device_ordinal on the local service. If + // allocate_space_for_deep_copy, the buffer is large enough to hold + // all sub-buffers of a tuple shape, otherwise it is only as large + // as the top-level tuple pointer array. + StatusOr> AllocateBufferOnDevice( + const Shape& shape, int device_ordinal, + bool allocate_space_for_deep_copy); + + // Executes the given computation with the given arguments and + // options. Arguments and result are "zero-copy", and are passed as pointers + // to device memory. See LocalExecuteOptions class comments for description of + // available options. 
The returned ShapedBuffer includes pointer(s) to device + // memory (DeviceMemoryBase) which are the caller's responsibility to + // deallocate. The layout of the result is chosen by the XLA service and + // should not be relied upon to be a specific value. If a specific result + // layout is needed, then the layout should be set in options. + // + // The arrays of arguments with different shapes or layouts are assumed not to + // alias. + // + // TODO(b/31220873): Remove ExecuteLocally methods. The path forward is to use + // Compile and run the returned LocalExecutable. + StatusOr> ExecuteLocally( + const Computation& computation, + const tensorflow::gtl::ArraySlice arguments, + const LocalExecuteOptions& options); + + // Overload of ExecuteLocally which writes the result into the given + // ShapedBuffer "result". Result is const because the ShapedBuffer data + // structure itself is not modified, only the buffers in device memory to + // which it refers. + // + // TODO(b/31220873): Remove ExecuteLocally methods. The path forward is to use + // Compile and run the returned LocalExecutable. + tensorflow::Status ExecuteLocally( + const Computation& computation, + const tensorflow::gtl::ArraySlice arguments, + const LocalExecuteOptions& options, ShapedBuffer* result); + + // Build and return a LocalExecutable object. The executable is compiled using + // the given argument layouts and options. + StatusOr> Compile( + const Computation& computation, + const tensorflow::gtl::ArraySlice argument_layouts, + const ExecutableBuildOptions& options); + + // Compiles the computation for ahead-of-time execution. This is intended for + // use in static compilation. The |argument_layouts| parameter is used to + // inform the compiler of the expected layout for arguments while + // |result_layout| is used to signal the layout of the result. The |options| + // parameter is used to request which target the compiler should emit code + // for. + // + // TODO(b/31222190): This doesn't really belong in LocalClient. Move it to its + // own library. + StatusOr> CompileAheadOfTime( + const Computation& computation, + const tensorflow::gtl::ArraySlice argument_layouts, + const Shape& result_layout, const AotCompilationOptions& options); + + // Returns the size of a pointer in bytes for a given triple. + static int64 PointerSizeForTriple(tensorflow::StringPiece triple); + + // Returns the platform that the underlying service targets. + perftools::gputools::Platform* platform() const; + + // Returns the number of devices on the system of the service platform + // type. Not all devices may be supported by the service (see + // device_ordinal_supported method). + int device_count() const; + + // Returns the default device ordinal that the service will run computations + // on if no device ordinal is specified in execute options. + int default_device_ordinal() const; + + // Returns whether the device with the given ordinal can be used by the + // service to execute computations. Not all devices of a particular platform + // may be usable by the service (eg, a GPU with insufficient CUDA compute + // capability). 
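+  // For instance (sketch; `client` is a LocalClient*):
+  //
+  //   for (int i = 0; i < client->device_count(); ++i) {
+  //     if (!client->device_ordinal_supported(i)) continue;
+  //     // ... safe to target device ordinal i ...
+  //   }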
+ bool device_ordinal_supported(int device_ordinal) const; + + private: + LocalService* local_service_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LOCAL_CLIENT_H_ diff --git a/tensorflow/compiler/xla/client/padding.cc b/tensorflow/compiler/xla/client/padding.cc new file mode 100644 index 0000000000..281fa10408 --- /dev/null +++ b/tensorflow/compiler/xla/client/padding.cc @@ -0,0 +1,122 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/padding.h" + +#include + +#include "tensorflow/core/lib/math/math_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +std::vector> MakePadding( + tensorflow::gtl::ArraySlice input_dimensions, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, Padding padding) { + CHECK_EQ(input_dimensions.size(), window_dimensions.size()); + CHECK_EQ(input_dimensions.size(), window_strides.size()); + std::vector> low_high_padding; + switch (padding) { + case Padding::kValid: + low_high_padding.resize(window_dimensions.size(), {0, 0}); + return low_high_padding; + + case Padding::kSame: + for (int64 i = 0; i < input_dimensions.size(); ++i) { + int64 input_dimension = input_dimensions[i]; + int64 window_dimension = window_dimensions[i]; + int64 window_stride = window_strides[i]; + // We follow the same convention as in Tensorflow, such that + // output dimension := ceil(input_dimension / window_stride). + // See tensorflow/tensorflow/python/ops/nn.py + // for the reference. See also tensorflow/core/kernels/ops_util.cc + // for the part where we avoid negative padding using max(0, x). + // + // + // For an odd sized window dimension 2N+1 with stride 1, the middle + // element is always inside the base area, so we can see it as N + 1 + + // N elements. In the example below, we have a kernel of size + // 2*3+1=7 so that the center element is 4 with 123 to the + // left and 567 to the right. + // + // base area: ------------------------ + // kernel at left: 1234567 + // kernel at right: 1234567 + // + // We can see visually here that we need to pad the base area + // by 3 on each side: + // + // padded base area: 000------------------------000 + // + // For an even number 2N, there are two options: + // + // *** Option A + // + // We view 2N as (N - 1) + 1 + N, so for N=3 we have 12 to the + // left, 3 is the center and 456 is to the right, like this: + // + // base area: ------------------------ + // kernel at left: 123456 + // kernel at right: 123456 + // padded base area: 00------------------------000 + // + // Note how we pad by one more to the right than to the left. 
+ // + // *** Option B + // + // We view 2N as N + 1 + (N - 1), so for N=3 we have 123 to + // the left, 4 is the center and 56 is to the right, like + // this: + // + // base area: ------------------------ + // kernel at left: 123456 + // kernel at right: 123456 + // padded base area: 000------------------------00 + // + // The choice here is arbitrary. We choose option A as this is + // what DistBelief and Tensorflow do. + // + // When the stride is greater than 1, the output size is smaller than + // the input base size. The base area is padded such that the last + // window fully fits in the padded base area, and the padding amount is + // evenly divided between the left and the right (or 1 more on the right + // if odd size padding is required). The example below shows the + // required padding when the base size is 10, the kernel size is 5, and + // the stride is 3. In this example, the output size is 4. + // + // base area: ---------- + // 1'st kernel: 12345 + // 2'nd kernel: 12345 + // 3'rd kernel: 12345 + // 4'th kernel: 12345 + // padded base area: 00----------00 + int64 output_dimension = + tensorflow::MathUtil::CeilOfRatio(input_dimension, window_stride); + int64 padding_size = + std::max((output_dimension - 1) * window_stride + + window_dimension - input_dimension, + 0); + low_high_padding.emplace_back( + tensorflow::MathUtil::FloorOfRatio(padding_size, 2ll), + tensorflow::MathUtil::CeilOfRatio(padding_size, 2ll)); + } + break; + } + + return low_high_padding; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/padding.h b/tensorflow/compiler/xla/client/padding.h new file mode 100644 index 0000000000..dce2d87e8d --- /dev/null +++ b/tensorflow/compiler/xla/client/padding.h @@ -0,0 +1,58 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_PADDING_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_PADDING_H_ + +#include +#include + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace xla { + +// Describes the padding applied for a windowed operation like +// convolution, where a window is placed inside a base area. +enum class Padding { + // Make the output have the same dimensions as the base area. For + // example, for a 3x3 base area and a 2x2 window, the output will be + // 3x3, so that requires padding the 3x3 base area to 4x4. + kSame, + + // Use no padding. For example, for a 4x4 base area and a 2x2 + // window, the output will be 3x3. + kValid, +}; + +// Returns the padding needed for the base area, given the base area dimensions, +// window dimensions, strides, and the type of padding. +// +// If v is the returned vector, then for each dimension number i, +// v[i].first is the padding to the left (i.e. in the direction of +// lower indices) and v[i].second is the padding to the right (i.e. in +// the direction of higher indices). 
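+//
+// For example (values consistent with padding_test.cc):
+//
+//   MakePadding({10}, {6}, {1}, Padding::kSame)  returns {{2, 3}}, and
+//   MakePadding({10}, {5}, {3}, Padding::kValid) returns {{0, 0}}.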
+// +// Precondition: The number of dimensions (i.e., rank) in input_dimensions, +// window_dimensions, and strides must match, which is equal to the number +// of elements in the result vector. +std::vector> MakePadding( + tensorflow::gtl::ArraySlice input_dimensions, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice strides, Padding padding); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_PADDING_H_ diff --git a/tensorflow/compiler/xla/client/padding_test.cc b/tensorflow/compiler/xla/client/padding_test.cc new file mode 100644 index 0000000000..deda5ce708 --- /dev/null +++ b/tensorflow/compiler/xla/client/padding_test.cc @@ -0,0 +1,91 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/padding.h" + +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +// Tests MakePadding utility function for various cases. +class PaddingTest : public ::testing::Test { + protected: + // A convenience function to test padding for a single dimension. + std::pair ComputePadding(int64 input_dimension, + int64 window_dimension, + int64 window_stride, Padding padding) { + return MakePadding({input_dimension}, {window_dimension}, {window_stride}, + padding)[0]; + } +}; + +TEST_F(PaddingTest, ValidPaddingWithStrideOne) { + const auto padding = ComputePadding(10, 5, 1, Padding::kValid); + EXPECT_EQ(padding.first, 0); + EXPECT_EQ(padding.second, 0); +} + +TEST_F(PaddingTest, ValidPaddingWithStrideThree) { + const auto padding = ComputePadding(10, 5, 3, Padding::kValid); + EXPECT_EQ(padding.first, 0); + EXPECT_EQ(padding.second, 0); +} + +TEST_F(PaddingTest, SamePaddingWithOddWindow) { + const auto padding = ComputePadding(10, 7, 1, Padding::kSame); + EXPECT_EQ(padding.first, 3); + EXPECT_EQ(padding.second, 3); +} + +TEST_F(PaddingTest, SamePaddingWithEvenWindow) { + const auto padding = ComputePadding(10, 6, 1, Padding::kSame); + EXPECT_EQ(padding.first, 2); + EXPECT_EQ(padding.second, 3); +} + +TEST_F(PaddingTest, SamePaddingWithOddWindowWithStride) { + const auto padding = ComputePadding(10, 7, 3, Padding::kSame); + EXPECT_EQ(padding.first, 3); + EXPECT_EQ(padding.second, 3); +} + +TEST_F(PaddingTest, SamePaddingWithEvenWindowWithStride) { + const auto padding = ComputePadding(10, 6, 4, Padding::kSame); + EXPECT_EQ(padding.first, 2); + EXPECT_EQ(padding.second, 2); +} + +TEST_F(PaddingTest, SamePaddingForWindowSizeOne) { + const auto padding = ComputePadding(10, 1, 1, Padding::kSame); + EXPECT_EQ(padding.first, 0); + EXPECT_EQ(padding.second, 0); +} + +TEST_F(PaddingTest, SamePaddingForWindowLargerThanInput) { + const auto padding = ComputePadding(10, 20, 1, Padding::kSame); + EXPECT_EQ(padding.first, 9); + EXPECT_EQ(padding.second, 10); +} + +// This used to trigger a case with negative padding. 
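+// Worked through the kSame formula in padding.cc: with base size 4, window 1,
+// and stride 2, the output size is CeilOfRatio(4, 2) = 2, so the padding is
+// max((2 - 1) * 2 + 1 - 4, 0) = max(-1, 0) = 0 on both sides.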
+TEST_F(PaddingTest, NonNegativePadding) { + const auto padding = ComputePadding(4, 1, 2, Padding::kSame); + EXPECT_EQ(padding.first, 0); + EXPECT_EQ(padding.second, 0); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/device_util.h b/tensorflow/compiler/xla/device_util.h new file mode 100644 index 0000000000..23a622b1ad --- /dev/null +++ b/tensorflow/compiler/xla/device_util.h @@ -0,0 +1,39 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Utilities common between the client and server for working with +// StreamExecutor devices. + +#ifndef TENSORFLOW_COMPILER_XLA_DEVICE_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_DEVICE_UTIL_H_ + +#include + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { + +// Returns a string that represents the device in terms of platform and ordinal; +// e.g. the first CUDA device will be "cuda:0" +string DeviceIdentifier(perftools::gputools::StreamExecutor* stream_exec) { + return tensorflow::strings::StrCat(stream_exec->platform()->Name(), ":", + stream_exec->device_ordinal()); +} + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_DEVICE_UTIL_H_ diff --git a/tensorflow/compiler/xla/differential_set.h b/tensorflow/compiler/xla/differential_set.h new file mode 100644 index 0000000000..9eae24ce30 --- /dev/null +++ b/tensorflow/compiler/xla/differential_set.h @@ -0,0 +1,63 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_DIFFERENTIAL_SET_H_ +#define TENSORFLOW_COMPILER_XLA_DIFFERENTIAL_SET_H_ + +#include + +#include "tensorflow/core/platform/macros.h" + +namespace xla { + +// In the base case, the differential set is just a set. +// However, you can also point a differential set at another differential set to +// use as a "parent". This makes a chain of sets, which each node in the chain +// adds some number of elements to the "Contains" property. +// +// E.g. if the base set holds {1, 2}, you can create a derived set that holds +// {3}, and the derived set will tell you it contains {1, 2, 3} whereas the base +// will continue to tell you it holds only {1, 2}. 
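+//
+// Illustrative usage sketch (relies only on the interface declared below):
+//
+//   DifferentialSet<int> base;
+//   base.Add(1);
+//   base.Add(2);
+//   DifferentialSet<int> derived(&base);
+//   derived.Add(3);
+//   derived.Contains(3);  // true, held directly.
+//   derived.Contains(1);  // true, found via the parent chain.
+//   base.Contains(3);     // false; a parent never sees child additions.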
+template +class DifferentialSet { + public: + // Constructs a differential set capable of holding values. It may have an + // ancestor link, which would it into a chain of sets. + explicit DifferentialSet(const DifferentialSet* parent = nullptr) + : parent_(parent) {} + + // Adds a value to be held directly by this set. + void Add(T value) { held_.insert(value); } + + // Returns whether this set holds the given value, or any ancestor in the + // chain of sets. + bool Contains(T value) const { + return held_.find(value) != held_.end() || + (parent_ != nullptr && parent_->Contains(value)); + } + + private: + // Values held directly by this node in the chain of sets. + std::set held_; + + // Parent node in the chain of sets. + const DifferentialSet* parent_; + + TF_DISALLOW_COPY_AND_ASSIGN(DifferentialSet); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_DIFFERENTIAL_SET_H_ diff --git a/tensorflow/compiler/xla/differential_set_test.cc b/tensorflow/compiler/xla/differential_set_test.cc new file mode 100644 index 0000000000..dacbbcc1ad --- /dev/null +++ b/tensorflow/compiler/xla/differential_set_test.cc @@ -0,0 +1,51 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/differential_set.h" + +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +TEST(DifferentialSetTest, TellsWhetherSetContainsSomethingHeld) { + DifferentialSet set; + set.Add(1); + set.Add(2); + EXPECT_FALSE(set.Contains(3)); + EXPECT_TRUE(set.Contains(1)); + EXPECT_TRUE(set.Contains(2)); + EXPECT_FALSE(set.Contains(0)); +} + +TEST(DifferentialSetTest, TellsWhetherSetContainsSomethingParentHolds) { + DifferentialSet parent; + parent.Add(1); + DifferentialSet child(&parent); + child.Add(2); + + // Test properties of the child. + EXPECT_FALSE(child.Contains(3)); + EXPECT_TRUE(child.Contains(1)); + EXPECT_TRUE(child.Contains(2)); + EXPECT_FALSE(child.Contains(0)); + + // Test properties of the parent. + EXPECT_TRUE(parent.Contains(1)); + EXPECT_FALSE(parent.Contains(2)); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/executable_run_options.cc b/tensorflow/compiler/xla/executable_run_options.cc new file mode 100644 index 0000000000..1c54fec97c --- /dev/null +++ b/tensorflow/compiler/xla/executable_run_options.cc @@ -0,0 +1,70 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/executable_run_options.h" + +namespace xla { + +ExecutableRunOptions& ExecutableRunOptions::set_device_ordinal( + int device_ordinal) { + device_ordinal_ = device_ordinal; + return *this; +} + +int ExecutableRunOptions::device_ordinal() const { return device_ordinal_; } + +ExecutableRunOptions& ExecutableRunOptions::set_allocator( + DeviceMemoryAllocator* allocator) { + allocator_ = allocator; + return *this; +} + +DeviceMemoryAllocator* ExecutableRunOptions::allocator() const { + return allocator_; +} + +ExecutableRunOptions& ExecutableRunOptions::set_stream( + perftools::gputools::Stream* stream) { + stream_ = stream; + return *this; +} + +perftools::gputools::Stream* ExecutableRunOptions::stream() const { + return stream_; +} + +ExecutableRunOptions& ExecutableRunOptions::set_inter_op_thread_pool( + tensorflow::thread::ThreadPool* inter_op_thread_pool) { + inter_op_thread_pool_ = inter_op_thread_pool; + return *this; +} + +tensorflow::thread::ThreadPool* ExecutableRunOptions::inter_op_thread_pool() + const { + return inter_op_thread_pool_; +} + +ExecutableRunOptions& ExecutableRunOptions::set_intra_op_thread_pool( + const Eigen::ThreadPoolDevice* intra_op_thread_pool) { + intra_op_thread_pool_ = intra_op_thread_pool; + return *this; +} + +const Eigen::ThreadPoolDevice* ExecutableRunOptions::intra_op_thread_pool() + const { + return intra_op_thread_pool_; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/executable_run_options.h b/tensorflow/compiler/xla/executable_run_options.h new file mode 100644 index 0000000000..212fce9eab --- /dev/null +++ b/tensorflow/compiler/xla/executable_run_options.h @@ -0,0 +1,87 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_EXECUTABLE_RUN_OPTIONS_H_ +#define TENSORFLOW_COMPILER_XLA_EXECUTABLE_RUN_OPTIONS_H_ + +// Intentionally forward declared so that ExecutableRunOptions can be linked +// into an XLA-compiled binary without having to link all of the pointed-to +// objects (e.g., for an ahead-of-time compiled CPU binary, the gpu tools don't +// need to be linked). +namespace perftools { +namespace gputools { +class Stream; +class Platform; +} +} + +namespace tensorflow { +namespace thread { +class ThreadPool; +} +} + +namespace Eigen { +struct ThreadPoolDevice; +} + +namespace xla { + +class DeviceMemoryAllocator; + +// Class containing options for running a LocalExecutable. +class ExecutableRunOptions { + public: + // Specifies the allocator to use during execution. + ExecutableRunOptions& set_allocator(DeviceMemoryAllocator* allocator); + DeviceMemoryAllocator* allocator() const; + + // If set, this is the device to run the computation on. 
Valid device_ordinal + // values are: 0 to # of devices - 1. These values are identical to the device + // ordinal values used by StreamExecutor. The device must be of the same type + // as the executable was compiled for. A value of -1 indicates this option has + // not been set. + ExecutableRunOptions& set_device_ordinal(int device_ordinal); + int device_ordinal() const; + + // If set, this is the stream to run the computation on. The platform of the + // stream must match the platform the executable was built for. A value of + // nullptr indicates the option has not been set. + ExecutableRunOptions& set_stream(perftools::gputools::Stream* stream); + perftools::gputools::Stream* stream() const; + + // Sets the thread pool on which to run parallel CPU backend + // computations. Does not take ownership. + ExecutableRunOptions& set_inter_op_thread_pool( + tensorflow::thread::ThreadPool* inter_op_thread_pool); + tensorflow::thread::ThreadPool* inter_op_thread_pool() const; + + // Sets the thread pool device on which to run Eigen subcomputations. + // Does not take ownership. + ExecutableRunOptions& set_intra_op_thread_pool( + const Eigen::ThreadPoolDevice* intra_op_thread_pool); + const Eigen::ThreadPoolDevice* intra_op_thread_pool() const; + + private: + DeviceMemoryAllocator* allocator_ = nullptr; + int device_ordinal_ = -1; + perftools::gputools::Stream* stream_ = nullptr; + tensorflow::thread::ThreadPool* inter_op_thread_pool_ = nullptr; + const Eigen::ThreadPoolDevice* intra_op_thread_pool_ = nullptr; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_EXECUTABLE_RUN_OPTIONS_H_ diff --git a/tensorflow/compiler/xla/index_util.cc b/tensorflow/compiler/xla/index_util.cc new file mode 100644 index 0000000000..901fcd89ea --- /dev/null +++ b/tensorflow/compiler/xla/index_util.cc @@ -0,0 +1,126 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/index_util.h" + +#include +#include + +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +/* static */ int64 IndexUtil::MultidimensionalIndexToLinearIndex( + const Shape& shape, tensorflow::gtl::ArraySlice multi_index) { + CHECK_EQ(shape.dimensions_size(), multi_index.size()); + // Padding and nested layouts not supported yet. 
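+  //
+  // Worked example (illustrative, consistent with index_util_test.cc): for a
+  // shape with dimensions {10, 20} and minor_to_major layout {0, 1}, the
+  // multi-index {3, 5} maps to linear index 3 + 5 * 10 = 53. The general
+  // formula is derived in the comment below.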
+ CHECK_EQ(0, shape.layout().padded_dimensions_size()); + + for (int i = 0; i < multi_index.size(); ++i) { + CHECK_GE(multi_index[i], 0); + CHECK_LT(multi_index[i], shape.dimensions(i)) + << "indexing beyond extent in dimension " << i << ":" + << "\n\tindex: " << tensorflow::str_util::Join(multi_index, ",") + << "\n\tshape: " << ShapeUtil::HumanString(shape); + } + + // Let the array be sized like so for dimensions i from 0 to n-1: + // + // [D{n-1} x D{n-2} x .. x D{0}] + // + // Let the order of the dimensions in the minor_to_major field in + // Layout be: + // + // L(0), L(1), ... , L(n-1) + // + // where L(0) is the most-minor dimension and L(n-1) the most-major. The + // multidimensional index: + // + // [I{0}, I{1}, ... , I{n-1}] + // + // then corresponds to the following linear index: + // + // linear_index = + // ((( ... + I{L(2)}) * D{L(1)} + I{L(1)}) * D{L(0)} + I{L(0)} + // + // or equivalently: + // + // linear_index = + // I{L(n-1)} * (D{L(n-2)} * D{L(n-3)} * D{L(n-4)} * .... D{L(0)}) + + // I{L(n-2)} * (D{L(n-3)} * D{L(n-4)} * .... D{L(0)}) + + // I{L(n-3)} * (D{L(n-4)} * .... D{L(0)}) + + // ... + + // I{L(2)} * (D{L(1)} * D{L(0)}) + + // I{L(1)} * D{L(0)} + + // I{L(0)} + // + // We compute the linear index value by accumulating the terms above from + // I{L(0)} up to I{L(n-1)}. Scale accumulates the product term D{L(0}} * + // D{L(1)} * ... + + // Scale factor holding the growing product of D{L(i)} terms. + int64 scale = 1; + int64 linear_index = 0; + for (auto dimension : shape.layout().minor_to_major()) { + linear_index += scale * multi_index[dimension]; + scale *= shape.dimensions(dimension); + } + return linear_index; +} + +/* static */ std::vector IndexUtil::LinearIndexToMultidimensionalIndex( + const Shape& shape, int64 linear_index) { + // Padding and nested layouts not supported yet. + CHECK_EQ(0, shape.layout().padded_dimensions_size()); + CHECK_GE(linear_index, 0); + CHECK_LT(linear_index, ShapeUtil::ElementsIn(shape)); + + // The following formula computes each element of the multidimensional index + // (See comments in MultidimensionalIndexToLinearIndex for notation): + // + // I{L(0)} = linear_index % D{L(0)} + // I{L(1)} = (linear_index / D{L(0)}) % D{L(1)} + // I{L(2)} = (linear_index / (D{L(0)} * D{L(1)})) % D{L(2)} + // ... + std::vector multi_index(shape.dimensions_size()); + + // Accumulated product D{L(0)} * D{L(1)} * ... + int64 divisor = 1; + for (auto dimension : shape.layout().minor_to_major()) { + multi_index[dimension] = + (linear_index / divisor) % shape.dimensions(dimension); + divisor *= shape.dimensions(dimension); + } + return multi_index; +} + +/* static */ bool IndexUtil::BumpIndices(const Shape& shape, + std::vector* indices) { + for (int64 dimno = indices->size() - 1; dimno >= 0; --dimno) { + int64 limit = shape.dimensions(dimno); + if ((*indices)[dimno] + 1 < limit) { + (*indices)[dimno]++; + std::fill(indices->begin() + dimno + 1, indices->end(), 0); + return true; + } + } + return false; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/index_util.h b/tensorflow/compiler/xla/index_util.h new file mode 100644 index 0000000000..2d8753c3fe --- /dev/null +++ b/tensorflow/compiler/xla/index_util.h @@ -0,0 +1,69 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Utility functions related to layouts of Shapes. + +#ifndef TENSORFLOW_COMPILER_XLA_INDEX_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_INDEX_UTIL_H_ + +#include + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/macros.h" + +namespace xla { + +// Namespaced collection of (static) utilities related to indexing into +// multidimensional arrays. +class IndexUtil { + public: + // Converts a multidimensional index (eg {x, y, z}) into a linear index based + // on the shape and its layout. The first index in the multi_index is + // dimension 0. + static int64 MultidimensionalIndexToLinearIndex( + const Shape& shape, tensorflow::gtl::ArraySlice multi_index); + + // Coverts a linear index into multidimensional index (eg {x, y, z}) based on + // the shape and its layout. The first index in the returned multidimensional + // index is dimension 0. + static std::vector LinearIndexToMultidimensionalIndex( + const Shape& shape, int64 linear_index); + + // Bumps a sequence of indices; e.g. {0,0,0,0} up by one index value; e.g. to + // {0,0,0,1}. This is akin to std::next_permutation. If the index hits a limit + // for the provided shape, the next most significant index is bumped, in a + // counting-up process. + // + // E.g. for shape f32[2,3] + // {0,0}=>{0,1} + // {0,1}=>{0,2} + // {0,2}=>{1,0} + // etc. + // + // This is useful for traversing the indices in a literal. + // + // Returns true iff the indices were successfully bumped; false if we've hit + // the limit where it can no longer be bumped in-bounds. + static bool BumpIndices(const Shape& shape, std::vector* indices); + + private: + TF_DISALLOW_COPY_AND_ASSIGN(IndexUtil); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_INDEX_UTIL_H_ diff --git a/tensorflow/compiler/xla/index_util_test.cc b/tensorflow/compiler/xla/index_util_test.cc new file mode 100644 index 0000000000..85259b33f0 --- /dev/null +++ b/tensorflow/compiler/xla/index_util_test.cc @@ -0,0 +1,159 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/index_util.h" + +#include + +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +void SetMinorToMajorLayout(Shape* shape, + std::initializer_list dimensions) { + shape->mutable_layout()->clear_minor_to_major(); + for (auto dimension : dimensions) { + shape->mutable_layout()->add_minor_to_major(dimension); + } +} + +TEST(IndexUtilTest, VectorIndexing) { + // Vectors are trivially laid out and the linear index should always be the + // same as the "multidimensional" index. + Shape vector_shape = ShapeUtil::MakeShape(F32, {100}); + EXPECT_EQ(42, + IndexUtil::MultidimensionalIndexToLinearIndex(vector_shape, {42})); + std::vector multi_index = + IndexUtil::LinearIndexToMultidimensionalIndex(vector_shape, 42); + EXPECT_EQ(1, multi_index.size()); + EXPECT_EQ(42, multi_index[0]); +} + +TEST(IndexUtilTest, MatrixIndexingRowMajor) { + // Set layout to [0, 1]. That is, row major. + Shape matrix_shape_01 = ShapeUtil::MakeShape(F32, {10, 20}); + SetMinorToMajorLayout(&matrix_shape_01, {0, 1}); + + // If index is {a, b} then linear index should be: a + b * 10 + EXPECT_EQ(0, IndexUtil::MultidimensionalIndexToLinearIndex(matrix_shape_01, + {0, 0})); + EXPECT_EQ(199, IndexUtil::MultidimensionalIndexToLinearIndex(matrix_shape_01, + {9, 19})); + EXPECT_EQ(53, IndexUtil::MultidimensionalIndexToLinearIndex(matrix_shape_01, + {3, 5})); + EXPECT_EQ(std::vector({3, 5}), + IndexUtil::LinearIndexToMultidimensionalIndex(matrix_shape_01, 53)); +} + +TEST(IndexUtilTest, MatrixIndexingColumnMajor) { + // Set layout to [1, 0]. That is, column major. + Shape matrix_shape_10 = ShapeUtil::MakeShape(F32, {10, 20}); + SetMinorToMajorLayout(&matrix_shape_10, {1, 0}); + + // If index is {a, b} then linear index should be: a * 20 + b + EXPECT_EQ(0, IndexUtil::MultidimensionalIndexToLinearIndex(matrix_shape_10, + {0, 0})); + EXPECT_EQ(199, IndexUtil::MultidimensionalIndexToLinearIndex(matrix_shape_10, + {9, 19})); + EXPECT_EQ(65, IndexUtil::MultidimensionalIndexToLinearIndex(matrix_shape_10, + {3, 5})); + EXPECT_EQ(std::vector({3, 5}), + IndexUtil::LinearIndexToMultidimensionalIndex(matrix_shape_10, 65)); +} + +TEST(IndexUtilTest, ThreeDArrayIndexing210) { + // Set layout to [2, 1, 0]. That is, column major. + Shape shape_210 = ShapeUtil::MakeShape(F32, {10, 20, 30}); + SetMinorToMajorLayout(&shape_210, {2, 1, 0}); + + // If index is {a, b, c} then linear index should be: + // a * 20 * 30 + b * 30 + c + EXPECT_EQ(1957, IndexUtil::MultidimensionalIndexToLinearIndex(shape_210, + {3, 5, 7})); + EXPECT_EQ(5277, IndexUtil::MultidimensionalIndexToLinearIndex(shape_210, + {8, 15, 27})); +} + +TEST(IndexUtilTest, ThreeDArrayIndexing120) { + // Set layout to [1, 2, 0] + Shape shape_120 = ShapeUtil::MakeShape(F32, {10, 20, 30}); + SetMinorToMajorLayout(&shape_120, {1, 2, 0}); + + // If index is {a, b, c} then linear index should be: + // a * 20 * 30 + b + c * 20 + EXPECT_EQ(1945, IndexUtil::MultidimensionalIndexToLinearIndex(shape_120, + {3, 5, 7})); + EXPECT_EQ(5355, IndexUtil::MultidimensionalIndexToLinearIndex(shape_120, + {8, 15, 27})); +} + +TEST(IndexUtilTest, FourDArrayIndexing3210) { + // Set layout to [3, 2, 1,0]. That is, column major. 
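+  // (With minor_to_major {3, 2, 1, 0}, dimension 3 varies fastest and
+  // dimension 0 slowest, matching the formula checked below.)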
+ Shape shape_3210 = ShapeUtil::MakeShape(F32, {10, 20, 30, 40}); + SetMinorToMajorLayout(&shape_3210, {3, 2, 1, 0}); + + // If index is {a, b, c, d} then linear index should be: + // a * 20 * 30 * 40 + b * 30 * 40 + c * 40 + d + EXPECT_EQ(78289, IndexUtil::MultidimensionalIndexToLinearIndex(shape_3210, + {3, 5, 7, 9})); + EXPECT_EQ(211113, IndexUtil::MultidimensionalIndexToLinearIndex( + shape_3210, {8, 15, 27, 33})); +} + +TEST(IndexUtilTest, LinearToMultiToLinear) { + // Verify that converting a linear index to a multidimensional index and back + // always returns the same value for different crazy shapes. Shape has + // 1440000000 elements. Inputs are randomly-ish selected. + std::vector linear_indexes = {0, 1439999999, 1145567336, + 43883404, 617295214, 1117613654}; + + std::vector> minor_to_major_orders; + minor_to_major_orders.push_back({6, 5, 4, 3, 2, 1, 0}); + minor_to_major_orders.push_back({0, 1, 2, 3, 4, 5, 6}); + minor_to_major_orders.push_back({4, 5, 1, 2, 6, 0, 3}); + + for (auto minor_to_major_order : minor_to_major_orders) { + Shape shape = ShapeUtil::MakeShape(F32, {10, 20, 30, 40, 30, 20, 10}); + SetMinorToMajorLayout(&shape, minor_to_major_order); + for (auto linear_index : linear_indexes) { + std::vector multi_index = + IndexUtil::LinearIndexToMultidimensionalIndex(shape, linear_index); + EXPECT_EQ(linear_index, IndexUtil::MultidimensionalIndexToLinearIndex( + shape, multi_index)); + } + } +} + +TEST(IndexUtilTest, BumpIndices2x2) { + auto shape = ShapeUtil::MakeShape(S32, {2, 2}); + std::vector indices = {0, 0}; + EXPECT_TRUE(IndexUtil::BumpIndices(shape, &indices)); + EXPECT_MATCH(indices, + testing::VectorMatcher(std::vector{0, 1})); + EXPECT_TRUE(IndexUtil::BumpIndices(shape, &indices)); + EXPECT_MATCH(indices, + testing::VectorMatcher(std::vector{1, 0})); + EXPECT_TRUE(IndexUtil::BumpIndices(shape, &indices)); + EXPECT_MATCH(indices, + testing::VectorMatcher(std::vector{1, 1})); + EXPECT_FALSE(IndexUtil::BumpIndices(shape, &indices)); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc new file mode 100644 index 0000000000..81eb717821 --- /dev/null +++ b/tensorflow/compiler/xla/layout_util.cc @@ -0,0 +1,363 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/layout_util.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/legacy_flags/layout_util_flags.h" +#include "tensorflow/compiler/xla/protobuf_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace xla { +namespace { + +using DimensionOrder = legacy_flags::DefaultLayout::DimensionOrder; + +// Internal helper for GetDefaultLayoutForShape and SetToDefaultLayout. Sets +// minor_to_major to the value that represents the default layout. +void SetDefaultLayoutToContainer( + tensorflow::protobuf::RepeatedField* + minor_to_major) { + const int64 size = minor_to_major->size(); + legacy_flags::LayoutUtilFlags* flags = legacy_flags::GetLayoutUtilFlags(); + auto default_layout = flags->xla_default_layout; + switch (default_layout.dimension_order) { + case DimensionOrder::kMajorToMinor: + for (int64 i = 0; i < size; ++i) { + minor_to_major->Set(i, size - 1 - i); + } + break; + case DimensionOrder::kMinorToMajor: + for (int64 i = 0; i < size; ++i) { + minor_to_major->Set(i, i); + } + break; + case DimensionOrder::kRandom: + for (int64 i = 0; i < size; ++i) { + minor_to_major->Set(i, i); + } + std::shuffle( + minor_to_major->begin(), minor_to_major->end(), + std::mt19937(default_layout.seed != 0 ? default_layout.seed + : std::random_device()())); + } +} + +} // namespace + +/* static */ Layout LayoutUtil::MakeLayout( + tensorflow::gtl::ArraySlice minor_to_major) { + Layout layout; + for (int64 dimension_number : minor_to_major) { + layout.add_minor_to_major(dimension_number); + } + return layout; +} + +namespace { + +// Internal helper that creates a default layout for an array of the given rank. +Layout CreateDefaultLayoutForRank(int64 rank) { + Layout layout; + tensorflow::protobuf::RepeatedField* + minor_to_major = layout.mutable_minor_to_major(); + minor_to_major->Resize(rank, 0); + SetDefaultLayoutToContainer(minor_to_major); + return layout; +} + +} // namespace + +/* static */ Layout LayoutUtil::GetDefaultLayoutForShape(const Shape& shape) { + // A Layout proto corresponds to a single array, not a tuple. + DCHECK(!ShapeUtil::IsTuple(shape)); + return CreateDefaultLayoutForRank(shape.dimensions_size()); +} + +/* static */ Layout LayoutUtil::GetDefaultLayoutForR2() { + return CreateDefaultLayoutForRank(2); +} + +/* static */ Layout LayoutUtil::GetDefaultLayoutForR3() { + return CreateDefaultLayoutForRank(3); +} + +/* static */ Layout LayoutUtil::GetDefaultLayoutForR4() { + return CreateDefaultLayoutForRank(4); +} + +/* static */ void LayoutUtil::SetToDefaultLayout(Shape* shape) { + if (ShapeUtil::IsTuple(*shape)) { + // Tuple shape. 
+ for (auto& element_shape : *shape->mutable_tuple_shapes()) { + SetToDefaultLayout(&element_shape); + } + } else { + tensorflow::protobuf::RepeatedField* + minor_to_major = shape->mutable_layout()->mutable_minor_to_major(); + minor_to_major->Resize(shape->dimensions_size(), 0); + SetDefaultLayoutToContainer(minor_to_major); + } +} + +/* static */ void LayoutUtil::SetToDefaultLayout(ProgramShape* program_shape) { + for (auto& parameter_shape : *program_shape->mutable_parameters()) { + LayoutUtil::SetToDefaultLayout(¶meter_shape); + } + LayoutUtil::SetToDefaultLayout(program_shape->mutable_result()); +} + +/* static */ tensorflow::Status LayoutUtil::ValidateLayoutInShape( + const Shape& shape) { + if (ShapeUtil::IsTuple(shape)) { + // Tuple shape. + if (shape.has_layout()) { + return InvalidArgument("tuple should not have a layout field"); + } + for (auto& element_shape : shape.tuple_shapes()) { + TF_RETURN_IF_ERROR(ValidateLayoutInShape(element_shape)); + } + return tensorflow::Status::OK(); + } else if (ShapeUtil::Rank(shape) == 0 && !shape.has_layout()) { + // A scalar without a layout is ok. + return tensorflow::Status::OK(); + } else { + // Array shape. + if (!shape.has_layout()) { + return InvalidArgument("shape does not have a layout"); + } + return ValidateLayoutForShape(shape.layout(), shape); + } +} + +/* static */ tensorflow::Status LayoutUtil::ValidateLayoutForShape( + const Layout& layout, const Shape& shape) { + if (ShapeUtil::IsTuple(shape)) { + return InvalidArgument("a single Layout is not valid for tuple shapes"); + } + + if (layout.minor_to_major_size() != ShapeUtil::Rank(shape)) { + return InvalidArgument( + "layout minor_to_major field contains %d elements, " + "but shape is rank %lld: {%s}; shape: %s", + layout.minor_to_major_size(), ShapeUtil::Rank(shape), + tensorflow::str_util::Join(layout.minor_to_major(), ", ").c_str(), + shape.ShortDebugString().c_str()); + } + + std::vector dimensions_in_layout(ShapeUtil::Rank(shape), false); + for (int64 i = 0; i < ShapeUtil::Rank(shape); ++i) { + int64 dim = layout.minor_to_major(i); + if (dim < 0 || dim >= ShapeUtil::Rank(shape)) { + return InvalidArgument( + "layout minor_to_major field has out-of-bounds value: %s", + HumanString(layout).c_str()); + } + if (dimensions_in_layout[dim]) { + return InvalidArgument( + "layout minor_to_major field has duplicate values: {%s}", + HumanString(layout).c_str()); + } + dimensions_in_layout[dim] = true; + } + + if (layout.padded_dimensions_size() > 0) { + if (layout.padded_dimensions_size() != ShapeUtil::Rank(shape)) { + return InvalidArgument( + "layout has %d padded dimensions, but shape is rank %lld", + layout.padded_dimensions_size(), ShapeUtil::Rank(shape)); + } + for (int i = 0; i < layout.padded_dimensions_size(); ++i) { + if (layout.padded_dimensions(i) < shape.dimensions(i)) { + return InvalidArgument( + "for dimension %d, dimension padding (%lld) is smaller than " + "the dimension size (%lld) of the shape", + i, layout.padded_dimensions(i), shape.dimensions(i)); + } + } + } + return tensorflow::Status::OK(); +} + +/* static */ void LayoutUtil::ClearLayout(Shape* shape) { + shape->clear_layout(); + for (auto& element_shape : *shape->mutable_tuple_shapes()) { + ClearLayout(&element_shape); + } +} + +/* static */ void LayoutUtil::ClearLayout(ProgramShape* program_shape) { + for (auto& parameter_shape : *program_shape->mutable_parameters()) { + LayoutUtil::ClearLayout(¶meter_shape); + } + LayoutUtil::ClearLayout(program_shape->mutable_result()); +} + +/* static */ bool 
LayoutUtil::IsMonotonicWithDim0Minor(const Layout& layout) { + return std::is_sorted(layout.minor_to_major().begin(), + layout.minor_to_major().end()); +} + +/* static */ bool LayoutUtil::IsMonotonicWithDim0Major(const Layout& layout) { + return std::is_sorted(layout.minor_to_major().begin(), + layout.minor_to_major().end(), std::greater()); +} + +/* static */ bool LayoutUtil::IsPadded(const Shape& shape) { + if (ShapeUtil::IsTuple(shape) || !HasLayout(shape) || + shape.layout().padded_dimensions_size() == 0) { + return false; + } + CHECK_EQ(shape.dimensions_size(), shape.layout().padded_dimensions_size()); + for (int64 i = 0; i < shape.dimensions_size(); ++i) { + if (shape.layout().padded_dimensions(i) > shape.dimensions(i)) { + return true; + } + } + return false; +} + +/* static */ bool LayoutUtil::HasLayout(const Shape& shape) { + if (ShapeUtil::IsTuple(shape)) { + // Tuple shape: all subshapes must have a layout. + return std::all_of(shape.tuple_shapes().begin(), shape.tuple_shapes().end(), + [](const Shape& s) { return HasLayout(s); }); + } + // A scalar trivially always has a layout. + return (ShapeUtil::Rank(shape) == 0 || + (shape.has_layout() && (shape.layout().minor_to_major_size() > 0))); +} + +/* static */ bool LayoutUtil::HasLayout(const ProgramShape& program_shape) { + for (auto& parameter_shape : program_shape.parameters()) { + if (!LayoutUtil::HasLayout(parameter_shape)) { + return false; + } + } + return LayoutUtil::HasLayout(program_shape.result()); +} + +/* static */ bool LayoutUtil::Equal(const Layout& lhs, const Layout& rhs) { + return protobuf_util::ProtobufEquals(lhs, rhs); +} + +/* static */ int64 LayoutUtil::Major(const Layout& layout, + int64 physical_dimension_number) { + CHECK_LE(0, physical_dimension_number); + CHECK_LT(physical_dimension_number, layout.minor_to_major_size()); + return Minor(layout, + layout.minor_to_major_size() - 1 - physical_dimension_number); +} + +/* static */ int64 LayoutUtil::Minor(const Layout& layout, + int64 physical_dimension_number) { + CHECK_LE(0, physical_dimension_number); + CHECK_LT(physical_dimension_number, layout.minor_to_major_size()); + return layout.minor_to_major(physical_dimension_number); +} + +/* static */ std::vector LayoutUtil::MakeLogicalToPhysical( + const Layout& layout) { + std::vector logical_to_physical(layout.minor_to_major_size()); + for (int64 physical = 0; physical < logical_to_physical.size(); ++physical) { + const int64 logical = Major(layout, physical); + logical_to_physical[logical] = physical; + } + return logical_to_physical; +} + +/* static */ string LayoutUtil::HumanString(const Layout& layout) { + return tensorflow::strings::StrCat( + "{", tensorflow::str_util::Join(layout.minor_to_major(), ","), "}"); +} + +namespace { + +// Internal helper for recursively copying layouts. 
+tensorflow::Status CopyLayoutInternal(const Shape& src, Shape* dst) { + if (ShapeUtil::IsTuple(src)) { + DCHECK(ShapeUtil::IsTuple(*dst)); + DCHECK_EQ(ShapeUtil::TupleElementCount(src), + ShapeUtil::TupleElementCount(*dst)); + for (int64 i = 0; i < ShapeUtil::TupleElementCount(src); ++i) { + TF_RETURN_IF_ERROR(CopyLayoutInternal(src.tuple_shapes(i), + dst->mutable_tuple_shapes(i))); + } + } else { + if (src.has_layout()) { + TF_RETURN_IF_ERROR( + LayoutUtil::ValidateLayoutForShape(src.layout(), *dst)); + *dst->mutable_layout() = src.layout(); + } else { + dst->clear_layout(); + } + } + return tensorflow::Status::OK(); +} + +} // namespace + +/* static */ +tensorflow::Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, + Shape* dst) { + if (!ShapeUtil::Compatible(src, *dst)) { + return InvalidArgument( + "cannot copy layout from shape %s to shape %s: " + "shapes are not compatible", + ShapeUtil::HumanString(src).c_str(), + ShapeUtil::HumanString(*dst).c_str()); + } + return CopyLayoutInternal(src, dst); +} + +/* static */ bool LayoutUtil::LayoutsInShapesEqual(const Shape& lhs, + const Shape& rhs) { + if (ShapeUtil::IsTuple(lhs) != ShapeUtil::IsTuple(rhs)) { + return false; + } + if (ShapeUtil::IsTuple(lhs)) { + if (ShapeUtil::TupleElementCount(lhs) != + ShapeUtil::TupleElementCount(rhs)) { + return false; + } + for (int i = 0; i < ShapeUtil::TupleElementCount(lhs); ++i) { + if (!LayoutsInShapesEqual(lhs.tuple_shapes(i), rhs.tuple_shapes(i))) { + return false; + } + } + return true; + } else { + return ShapeUtil::SameDimensions(lhs, rhs) && + LayoutUtil::Equal(lhs.layout(), rhs.layout()); + } +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h new file mode 100644 index 0000000000..984bf402cd --- /dev/null +++ b/tensorflow/compiler/xla/layout_util.h @@ -0,0 +1,153 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Utility functions related to layouts of Shapes. + +#ifndef TENSORFLOW_COMPILER_XLA_LAYOUT_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_LAYOUT_UTIL_H_ + +#include + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// Namespaced collection of (static) Layout utilities. +class LayoutUtil { + public: + // Creates a layout with the given minor-to-major dimension order. (This is a + // convenience function for protobuf construction.) + static Layout MakeLayout(tensorflow::gtl::ArraySlice minor_to_major); + + // Returns default layout for the given shape. + static Layout GetDefaultLayoutForShape(const Shape& shape); + + // Helper functions that create default layouts for various ranks. 
+ static Layout GetDefaultLayoutForR2(); + static Layout GetDefaultLayoutForR3(); + static Layout GetDefaultLayoutForR4(); + + // Sets the default layout on the Shape. + static void SetToDefaultLayout(Shape* shape); + + // Sets the layouts of all Shapes within the given ProgramShape to the + // default. + static void SetToDefaultLayout(ProgramShape* program_shape); + + // Validates that the layout within the given shape is correct. + static tensorflow::Status ValidateLayoutInShape(const Shape& shape); + + // Validates that the provided layout satisfies invariants for the given + // shape. + static tensorflow::Status ValidateLayoutForShape(const Layout& layout, + const Shape& shape); + + // Clears the layout in the given Shape. After this function is called, + // HasLayout will return false for the shape. + static void ClearLayout(Shape* shape); + + // Clears the layout on all Shapes within the given ProgramShape. + static void ClearLayout(ProgramShape* program_shape); + + // Returns whether the layout is monotonic and dim 0 is minor in the layout. + // * R0 and R1: this is always trivially true. + // * R2+: equivalent to column-major. Dimension 0 is the minor, dimension 1 is + // more major, and so on until dimension N-1 which is the major. + static bool IsMonotonicWithDim0Minor(const Layout& layout); + + // Returns whether the layout is monotonic and dim 0 is major in the layout. + // * R0 and R1: this is always trivially true. + // * R2+: equivalent to row-major. Dimension 0 is the major, dimension 1 is + // more minor, and so on until dimension N-1 which is the minor. + static bool IsMonotonicWithDim0Major(const Layout& layout); + + // Returns whether the layout of the given shape has padding (a + // padded_dimension value in Layout is greater than the corresponding + // dimension size). + static bool IsPadded(const Shape& shape); + + // Returns whether the given shape has a layout. For tuple shapes, true is + // returned only if all elements have layouts. + static bool HasLayout(const Shape& shape); + + // Returns whether all Shapes within the given ProgramShape have layouts. + static bool HasLayout(const ProgramShape& program_shape); + + // Returns whether lhs and rhs are identical. + static bool Equal(const Layout& lhs, const Layout& rhs); + + // Major(0) is the most major logical dimension number, major(1) is the + // second-most-major logical dimension number and so on. + // + // This can be used to translate physical dimension numbers to logical + // dimension numbers. Assume that we are numbering the physical dimensions so + // that the most major physical dimension has physical dimension number 0 and + // so on. Then a physical dimension number p corresponds to the logical + // dimension number Major(p). So this function could also be called + // PhysicalToLogical(). + // + // As an example, consider physical dimension number 0, which by definition is + // the most major. Then Major(0) is the most major logical dimension, so Major + // maps the physical dimension number 0 to the most major logical dimension + // number Major(0). + static int64 Major(const Layout& layout, int64 physical_dimension_number); + + // Minor(0) is the most minor logical dimension number, minor(1) is the + // second-most-minor logical dimension number and so on. + static int64 Minor(const Layout& layout, int64 physical_dimension_number); + + // Returns the inverse mapping of the Major() function. More precisely, return + // a vector v such that if l == Major(p), then v[l] == p. 
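+  //
+  // For example (an illustrative value, derived from the definitions above):
+  // for a rank-3 layout with minor_to_major {0, 2, 1}, Major yields the
+  // sequence {1, 2, 0}, so this function returns {2, 0, 1}.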
+ // + // This can be used to translate logical dimension numbers into physical + // dimension numbers. Assume that we are numbering the physical dimensions so + // that the most major physical dimension has physical dimension number 0 and + // so on. Then a logical dimension number l corresponds to the physical + // dimension number MakeLogicalToPhysical(layout)[l]. + // + // As an example, consider physical dimension number 0, which by definition is + // the most major. Then l := Major(0) is the most major logical dimension. If + // v is the vector returned from this function, then v[l] == 0. So v maps the + // most major logical dimension l to the physical dimension number 0. + static std::vector MakeLogicalToPhysical(const Layout& layout); + + // Returns a human-readable string that represents the given layout. + static string HumanString(const Layout& layout); + + // Copies the layout from 'src' to 'dst'. Recursively copies layouts of + // tuples. 'src' and 'dst' must be compatible. + static tensorflow::Status CopyLayoutBetweenShapes(const Shape& src, + Shape* dst); + + // Returns true if the layouts of lhs and rhs are equal, false + // otherwise. Recursively compares layouts of tuples. + // + // Since the structure of the shape determines the structure of the layout, + // this returns true if and only if the shapes and layouts are identical + // except that the element type is ignored. + static bool LayoutsInShapesEqual(const Shape& lhs, const Shape& rhs); + + private: + TF_DISALLOW_COPY_AND_ASSIGN(LayoutUtil); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_LAYOUT_UTIL_H_ diff --git a/tensorflow/compiler/xla/layout_util_test.cc b/tensorflow/compiler/xla/layout_util_test.cc new file mode 100644 index 0000000000..5ba1122946 --- /dev/null +++ b/tensorflow/compiler/xla/layout_util_test.cc @@ -0,0 +1,246 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/shape_util.h" + +#include "tensorflow/compiler/xla/legacy_flags/layout_util_flags.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +class LayoutUtilTest : public ::testing::Test { + protected: + Shape MakeShapeWithLayout(PrimitiveType element_type, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice minor_to_major) { + Shape shape = ShapeUtil::MakeShape(element_type, dimensions); + *shape.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major); + return shape; + } +}; + +TEST_F(LayoutUtilTest, TupleLayoutComparison) { + Shape shape = + ShapeUtil::MakeTupleShape({MakeShapeWithLayout(F32, {2, 3}, {0, 1})}); + Shape other_shape = + ShapeUtil::MakeTupleShape({MakeShapeWithLayout(F32, {2, 2}, {0, 1})}); + + Shape tuple0 = ShapeUtil::MakeTupleShape({}); + Shape tuple1 = ShapeUtil::MakeTupleShape({shape}); + Shape tuple2 = ShapeUtil::MakeTupleShape({shape, shape}); + + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(tuple0, tuple0)); + EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(tuple0, tuple1)); + EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(tuple0, tuple2)); + EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(tuple1, tuple0)); + EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(tuple2, tuple0)); + + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(tuple1, tuple1)); + EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(tuple1, tuple2)); + EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(tuple2, tuple1)); + + Shape other_tuple2 = ShapeUtil::MakeTupleShape({shape, other_shape}); + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(tuple2, tuple2)); + EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(tuple2, other_tuple2)); + EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(other_tuple2, tuple2)); +} + +TEST_F(LayoutUtilTest, CopyLayoutArray) { + Shape src = MakeShapeWithLayout(F32, {2, 3}, {0, 1}); + Shape dst = MakeShapeWithLayout(F32, {2, 3}, {1, 0}); + + EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + + // Should work if destination has no layout. + dst.clear_layout(); + EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + + // If source is cleared, then destination should be cleared. 
+ src.clear_layout(); + EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + EXPECT_TRUE(dst.has_layout()); + EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + EXPECT_FALSE(dst.has_layout()); +} + +TEST_F(LayoutUtilTest, CopyLayoutTuple) { + Shape src = ShapeUtil::MakeTupleShape( + {MakeShapeWithLayout(F32, {2, 3}, {0, 1}), + MakeShapeWithLayout(F32, {42, 123}, {1, 0}), + ShapeUtil::MakeTupleShape( + {MakeShapeWithLayout(F32, {}, {}), + MakeShapeWithLayout(F32, {1, 2, 3}, {0, 2, 1})})}); + Shape dst = ShapeUtil::MakeTupleShape( + {MakeShapeWithLayout(F32, {2, 3}, {1, 0}), + MakeShapeWithLayout(F32, {42, 123}, {1, 0}), + ShapeUtil::MakeTupleShape( + {MakeShapeWithLayout(F32, {}, {}), + MakeShapeWithLayout(F32, {1, 2, 3}, {1, 2, 0})})}); + + EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); +} + +TEST_F(LayoutUtilTest, CopyLayoutNotCompatible) { + Shape src = MakeShapeWithLayout(F32, {123, 42, 7}, {2, 0, 1}); + Shape dst = MakeShapeWithLayout(F32, {2, 3}, {1, 0}); + auto status = LayoutUtil::CopyLayoutBetweenShapes(src, &dst); + EXPECT_FALSE(status.ok()); + EXPECT_MATCH(status.error_message(), + testing::ContainsRegex("cannot copy layout from shape")); +} + +TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleTuple) { + Shape src = + ShapeUtil::MakeTupleShape({MakeShapeWithLayout(F32, {2, 3}, {0, 1}), + MakeShapeWithLayout(F32, {42, 123}, {1, 0}), + ShapeUtil::MakeTupleShape({MakeShapeWithLayout( + F32, {1, 2, 3}, {0, 2, 1})})}); + Shape dst = ShapeUtil::MakeTupleShape( + {MakeShapeWithLayout(F32, {2, 3}, {1, 0}), + MakeShapeWithLayout(F32, {42, 123}, {1, 0}), + ShapeUtil::MakeTupleShape( + {MakeShapeWithLayout(F32, {}, {}), + MakeShapeWithLayout(F32, {1, 2, 3}, {1, 2, 0})})}); + + auto status = LayoutUtil::CopyLayoutBetweenShapes(src, &dst); + EXPECT_FALSE(status.ok()); + EXPECT_MATCH(status.error_message(), + testing::ContainsRegex("cannot copy layout from shape")); +} + +TEST_F(LayoutUtilTest, CopyLayoutBogusLayout) { + Shape src = ShapeUtil::MakeShape(F32, {2, 3}); + Shape dst = ShapeUtil::MakeShape(F32, {2, 3}); + // Set layout to invalid value. 
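+  // (A rank-4 minor_to_major list on a rank-2 shape, so ValidateLayoutForShape
+  // should reject it during the copy.)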
+ *src.mutable_layout() = LayoutUtil::MakeLayout({1, 2, 3, 4}); + + auto status = LayoutUtil::CopyLayoutBetweenShapes(src, &dst); + EXPECT_FALSE(status.ok()); + EXPECT_MATCH(status.error_message(), + testing::ContainsRegex("layout minor_to_major field contains .* " + "elements, but shape is rank")); +} + +TEST_F(LayoutUtilTest, ClearLayoutTuple) { + Shape shape = ShapeUtil::MakeTupleShape( + {MakeShapeWithLayout(F32, {2, 3}, {1, 0}), + MakeShapeWithLayout(F32, {42, 123}, {1, 0}), + ShapeUtil::MakeTupleShape( + {MakeShapeWithLayout(F32, {}, {}), + MakeShapeWithLayout(F32, {1, 2, 3}, {1, 2, 0})})}); + EXPECT_TRUE(LayoutUtil::HasLayout(shape)); + EXPECT_TRUE(shape.tuple_shapes(0).has_layout()); + EXPECT_TRUE(shape.tuple_shapes(2).tuple_shapes(1).has_layout()); + + LayoutUtil::ClearLayout(&shape); + + EXPECT_FALSE(LayoutUtil::HasLayout(shape)); + EXPECT_FALSE(shape.tuple_shapes(0).has_layout()); + EXPECT_FALSE(shape.tuple_shapes(2).tuple_shapes(1).has_layout()); +} + +TEST_F(LayoutUtilTest, SetToDefaultLayoutTuple) { + Shape shape = ShapeUtil::MakeTupleShape( + {MakeShapeWithLayout(F32, {2, 3, 4}, {1, 0, 2}), + MakeShapeWithLayout(F32, {42, 123, 7}, {1, 2, 0}), + ShapeUtil::MakeTupleShape( + {MakeShapeWithLayout(F32, {}, {}), + MakeShapeWithLayout(F32, {1, 2, 3, 4}, {3, 1, 2, 0})})}); + EXPECT_FALSE(LayoutUtil::Equal(shape.tuple_shapes(0).layout(), + shape.tuple_shapes(1).layout())); + LayoutUtil::SetToDefaultLayout(&shape); + EXPECT_TRUE(LayoutUtil::Equal(shape.tuple_shapes(0).layout(), + shape.tuple_shapes(1).layout())); + EXPECT_TRUE(LayoutUtil::Equal( + LayoutUtil::GetDefaultLayoutForShape(shape.tuple_shapes(0)), + shape.tuple_shapes(1).layout())); +} + +TEST_F(LayoutUtilTest, IsPadded) { + Shape shape_without_layout = ShapeUtil::MakeShape(F32, {2, 3, 4}); + LayoutUtil::ClearLayout(&shape_without_layout); + EXPECT_FALSE(LayoutUtil::IsPadded(shape_without_layout)); + + Shape shape_with_layout = ShapeUtil::MakeShape(F32, {2, 3, 4}); + LayoutUtil::SetToDefaultLayout(&shape_with_layout); + EXPECT_FALSE(LayoutUtil::IsPadded(shape_with_layout)); + + // Add padding equal to the dimension sizes. In this case the padding is a + // nop. + Shape shape_with_degenerate_padding = ShapeUtil::MakeShape(F32, {2, 3, 4}); + shape_with_degenerate_padding.mutable_layout()->add_padded_dimensions(2); + shape_with_degenerate_padding.mutable_layout()->add_padded_dimensions(3); + shape_with_degenerate_padding.mutable_layout()->add_padded_dimensions(4); + EXPECT_FALSE(LayoutUtil::IsPadded(shape_with_degenerate_padding)); + + Shape shape_with_padding = ShapeUtil::MakeShape(F32, {2, 3, 4}); + shape_with_padding.mutable_layout()->add_padded_dimensions(2); + shape_with_padding.mutable_layout()->add_padded_dimensions(14); + shape_with_padding.mutable_layout()->add_padded_dimensions(42); + EXPECT_TRUE(LayoutUtil::IsPadded(shape_with_padding)); +} + +TEST_F(LayoutUtilTest, DefaultLayoutGettersMajorToMinor) { + // Test that LayoutUtil returns expected layouts when the xla_default_layout + // flag is set to kMajorToMinor. 
+ legacy_flags::LayoutUtilFlags* flags = legacy_flags::GetLayoutUtilFlags(); + flags->xla_default_layout = xla::legacy_flags::DefaultLayout{ + .dimension_order = + legacy_flags::DefaultLayout::DimensionOrder::kMajorToMinor}; + + EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({1, 0}), + LayoutUtil::GetDefaultLayoutForR2())); + EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({2, 1, 0}), + LayoutUtil::GetDefaultLayoutForR3())); + EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({3, 2, 1, 0}), + LayoutUtil::GetDefaultLayoutForR4())); + EXPECT_TRUE( + LayoutUtil::Equal(LayoutUtil::MakeLayout({4, 3, 2, 1, 0}), + LayoutUtil::GetDefaultLayoutForShape( + ShapeUtil::MakeShape(F32, {10, 20, 30, 15, 25})))); +} + +TEST_F(LayoutUtilTest, DefaultLayoutGettersMinorToMajor) { + // Test that LayoutUtil returns expected layouts when the xla_default_layout + // flag is set to kMinorToMajor. + legacy_flags::LayoutUtilFlags* flags = legacy_flags::GetLayoutUtilFlags(); + flags->xla_default_layout = xla::legacy_flags::DefaultLayout{ + .dimension_order = + legacy_flags::DefaultLayout::DimensionOrder::kMinorToMajor}; + + EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({0, 1}), + LayoutUtil::GetDefaultLayoutForR2())); + EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({0, 1, 2}), + LayoutUtil::GetDefaultLayoutForR3())); + EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({0, 1, 2, 3}), + LayoutUtil::GetDefaultLayoutForR4())); + EXPECT_TRUE( + LayoutUtil::Equal(LayoutUtil::MakeLayout({0, 1, 2, 3, 4}), + LayoutUtil::GetDefaultLayoutForShape( + ShapeUtil::MakeShape(F32, {10, 20, 30, 15, 25})))); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/BUILD b/tensorflow/compiler/xla/legacy_flags/BUILD new file mode 100644 index 0000000000..c98232cdf6 --- /dev/null +++ b/tensorflow/compiler/xla/legacy_flags/BUILD @@ -0,0 +1,267 @@ +# Legacy command line flags for the XLA libraries. + +# Please do not add more flags to this package. + +# The XLA libraries were written in an environment that allowed command - line +# flags to be scattered freely throughout the libraries. This model, while +# initially convenient, leads to a proliferation in unused commnd line flags in +# tests and binaries, and serious problems in servers, where one might wish +# parameters to be different in independent RPC calls to the same routine. +# +# Please don't add more flags. If you're a library author, pass options and +# parameters explicitly through the library's interface. 
+ +package(default_visibility = ["//tensorflow:internal"]) + +licenses(["notice"]) # Apache 2.0 + +cc_library( + name = "parse_flags_from_env", + srcs = ["parse_flags_from_env.cc"], + hdrs = ["parse_flags_from_env.h"], + deps = + [ + "//tensorflow/compiler/xla:types", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "parse_flags_from_env_test", + srcs = ["parse_flags_from_env_test.cc"], + deps = + [ + ":parse_flags_from_env", + "//tensorflow/compiler/xla:types", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +cc_library( + name = "layout_util_flags", + srcs = ["layout_util_flags.cc"], + hdrs = ["layout_util_flags.h"], + deps = + [ + ":parse_flags_from_env", + "//tensorflow/compiler/xla:types", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "util_flags", + srcs = ["util_flags.cc"], + hdrs = ["util_flags.h"], + deps = + [ + ":parse_flags_from_env", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "cpu_compiler_flags", + srcs = ["cpu_compiler_flags.cc"], + hdrs = ["cpu_compiler_flags.h"], + deps = + [ + ":parse_flags_from_env", + "//tensorflow/compiler/xla:types", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "cpu_runtime_flags", + srcs = ["cpu_runtime_flags.cc"], + hdrs = ["cpu_runtime_flags.h"], + deps = + [ + ":parse_flags_from_env", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "llvm_backend_flags", + srcs = ["llvm_backend_flags.cc"], + hdrs = ["llvm_backend_flags.h"], + deps = [ + ":parse_flags_from_env", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "@llvm//:core", + ], +) + +cc_library( + name = "compiler_functor_flags", + srcs = ["compiler_functor_flags.cc"], + hdrs = ["compiler_functor_flags.h"], + deps = [ + ":parse_flags_from_env", + "//tensorflow/compiler/xla:types", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "convolution_thunk_flags", + srcs = ["convolution_thunk_flags.cc"], + hdrs = ["convolution_thunk_flags.h"], + deps = [ + ":parse_flags_from_env", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "gpu_compiler_flags", + srcs = ["gpu_compiler_flags.cc"], + hdrs = ["gpu_compiler_flags.h"], + deps = [ + ":parse_flags_from_env", + "//tensorflow/compiler/xla:types", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "gpu_backend_lib_flags", + srcs = ["gpu_backend_lib_flags.cc"], + hdrs = ["gpu_backend_lib_flags.h"], + deps = [ + ":parse_flags_from_env", + "//tensorflow/compiler/xla:types", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "stream_assignment_flags", + srcs = ["stream_assignment_flags.cc"], + hdrs = ["stream_assignment_flags.h"], + deps = [ + ":parse_flags_from_env", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "hlo_graph_dumper_flags", + srcs = ["hlo_graph_dumper_flags.cc"], + hdrs = ["hlo_graph_dumper_flags.h"], + deps = [ + ":parse_flags_from_env", + "//tensorflow/compiler/xla:types", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "hlo_pass_pipeline_flags", + srcs = ["hlo_pass_pipeline_flags.cc"], + hdrs = 
["hlo_pass_pipeline_flags.h"], + deps = [ + ":parse_flags_from_env", + "//tensorflow/compiler/xla:types", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "alias_analysis_flags", + srcs = ["alias_analysis_flags.cc"], + hdrs = ["alias_analysis_flags.h"], + deps = [ + ":parse_flags_from_env", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "llvm_util_flags", + srcs = ["llvm_util_flags.cc"], + hdrs = ["llvm_util_flags.h"], + deps = [ + ":parse_flags_from_env", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "service_flags", + srcs = ["service_flags.cc"], + hdrs = ["service_flags.h"], + deps = [ + ":parse_flags_from_env", + "//tensorflow/compiler/xla:types", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "buffer_assignment_flags", + srcs = ["buffer_assignment_flags.cc"], + hdrs = ["buffer_assignment_flags.h"], + deps = [ + ":parse_flags_from_env", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "hlo_test_base_flags", + srcs = ["hlo_test_base_flags.cc"], + hdrs = ["hlo_test_base_flags.h"], + deps = [ + ":parse_flags_from_env", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "backend_flags", + srcs = ["backend_flags.cc"], + hdrs = ["backend_flags.h"], + deps = [ + ":parse_flags_from_env", + "//tensorflow/compiler/xla:types", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], +) + +# ----------------------------------------------------------------------------- + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.cc b/tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.cc new file mode 100644 index 0000000000..474753c10a --- /dev/null +++ b/tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.cc @@ -0,0 +1,62 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Legacy flags for XLA's alias_analysis module. + +#include // NOLINT(build/c++11): only using std::call_once, not mutex. +#include + +#include "tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace xla { +namespace legacy_flags { + +// Pointers to the parsed value of the flags and flag descriptors, initialized +// via flags_init. +static AliasAnalysisFlags* flags; +static std::vector* flag_list; +static std::once_flag flags_init; + +// Allocate *flags. Called via call_once(&flags_init,...). 
+static void AllocateFlags() { + flags = new AliasAnalysisFlags; + flags->xla_emit_alias_scope = true; + flag_list = new std::vector({ + tensorflow::Flag("xla_emit_alias_scope", &flags->xla_emit_alias_scope, + "Use buffer analysis to refine alias-analysis."), + }); + ParseFlagsFromEnv(*flag_list); +} + +// Append to *append_to flag definitions associated with XLA's alias_analysis +// module. +void AppendAliasAnalysisFlags(std::vector* append_to) { + std::call_once(flags_init, &AllocateFlags); + append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); +} + +// Return a pointer to the AliasAnalysisFlags struct; +// repeated calls return the same pointer. +// This should be called only after Flags::Parse() has returned. +AliasAnalysisFlags* GetAliasAnalysisFlags() { + std::call_once(flags_init, &AllocateFlags); + return flags; +} + +} // namespace legacy_flags +} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.h b/tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.h new file mode 100644 index 0000000000..369f8cd7ca --- /dev/null +++ b/tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.h @@ -0,0 +1,46 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_ALIAS_ANALYSIS_FLAGS_H_ +#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_ALIAS_ANALYSIS_FLAGS_H_ + +// Legacy flags for XLA's alias_analysis module. + +#include + +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace xla { +namespace legacy_flags { + +// Append to *flag_list flag definitions associated with XLA's alias_analysis +// module. +void AppendAliasAnalysisFlags(std::vector* flag_list); + +// The values of flags associated with XLA's alias_analysis module. +typedef struct { + bool xla_emit_alias_scope; // Use buffer analysis to refine alias-analysis. +} AliasAnalysisFlags; + +// Return a pointer to the AliasAnalysisFlags struct; +// repeated calls return the same pointer. +// This should be called only after Flags::Parse() has returned. +AliasAnalysisFlags* GetAliasAnalysisFlags(); + +} // namespace legacy_flags +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_ALIAS_ANALYSIS_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/backend_flags.cc b/tensorflow/compiler/xla/legacy_flags/backend_flags.cc new file mode 100644 index 0000000000..7c007f4435 --- /dev/null +++ b/tensorflow/compiler/xla/legacy_flags/backend_flags.cc @@ -0,0 +1,63 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Legacy flags for XLA's backend module. + +#include // NOLINT(build/c++11): only using std::call_once, not mutex. +#include + +#include "tensorflow/compiler/xla/legacy_flags/backend_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace xla { +namespace legacy_flags { + +// Pointers to the parsed value of the flags and flag descriptors, initialized +// via flags_init. +static BackendFlags* flags; +static std::vector* flag_list; +static std::once_flag flags_init; + +// Allocate *flags. Called via call_once(&flags_init,...). +static void AllocateFlags() { + flags = new BackendFlags; + // TODO(b/32648682): Decide if this should continue to be a flag longer term. + flags->xla_replicas = 1; + flag_list = new std::vector({ + tensorflow::Flag( + "xla_replicas", &flags->xla_replicas, + "The number of replicas to use. 1 means no replication."), + }); + ParseFlagsFromEnv(*flag_list); +} + +// Append to *append_to flag definitions associated with XLA's backend module. +void AppendBackendFlags(std::vector* append_to) { + std::call_once(flags_init, &AllocateFlags); + append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); +} + +// Return a pointer to the BackendFlags struct; +// repeated calls return the same pointer. +// This should be called only after Flags::Parse() has returned. +BackendFlags* GetBackendFlags() { + std::call_once(flags_init, &AllocateFlags); + return flags; +} + +} // namespace legacy_flags +} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/backend_flags.h b/tensorflow/compiler/xla/legacy_flags/backend_flags.h new file mode 100644 index 0000000000..061238b7e6 --- /dev/null +++ b/tensorflow/compiler/xla/legacy_flags/backend_flags.h @@ -0,0 +1,46 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_BACKEND_FLAGS_H_ +#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_BACKEND_FLAGS_H_ + +// Legacy flags for XLA's backend module. + +#include + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace xla { +namespace legacy_flags { + +// Append to *flag_list flag definitions associated with XLA's backend module. +void AppendBackendFlags(std::vector* flag_list); + +// The values of flags associated with XLA's backend module. 
+typedef struct { + int64 xla_replicas; // The number of replicas to use. 1 means no + // replication. +} BackendFlags; + +// Return a pointer to the BackendFlags struct; +// repeated calls return the same pointer. +// This should be called only after Flags::Parse() has returned. +BackendFlags* GetBackendFlags(); + +} // namespace legacy_flags +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_BACKEND_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/buffer_assignment_flags.cc b/tensorflow/compiler/xla/legacy_flags/buffer_assignment_flags.cc new file mode 100644 index 0000000000..71873f73af --- /dev/null +++ b/tensorflow/compiler/xla/legacy_flags/buffer_assignment_flags.cc @@ -0,0 +1,63 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Legacy flags for XLA's buffer_assignment module. + +#include // NOLINT(build/c++11): only using std::call_once, not mutex. +#include + +#include "tensorflow/compiler/xla/legacy_flags/buffer_assignment_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace xla { +namespace legacy_flags { + +// Pointers to the parsed value of the flags and flag descriptors, initialized +// via flags_init. +static BufferAssignmentFlags* flags; +static std::vector* flag_list; +static std::once_flag flags_init; + +// Allocate *flags. Called via call_once(&flags_init,...). +static void AllocateFlags() { + flags = new BufferAssignmentFlags; + flags->xla_enable_buffer_reuse = true; + flag_list = new std::vector({ + tensorflow::Flag("xla_enable_buffer_reuse", + &flags->xla_enable_buffer_reuse, + "Enable reuse of buffers."), + }); + ParseFlagsFromEnv(*flag_list); +} + +// Append to *append_to flag definitions associated with XLA's buffer_assignment +// module. +void AppendBufferAssignmentFlags(std::vector* append_to) { + std::call_once(flags_init, &AllocateFlags); + append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); +} + +// Return a pointer to the BufferAssignmentFlags struct; +// repeated calls return the same pointer. +// This should be called only after Flags::Parse() has returned. +BufferAssignmentFlags* GetBufferAssignmentFlags() { + std::call_once(flags_init, &AllocateFlags); + return flags; +} + +} // namespace legacy_flags +} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/buffer_assignment_flags.h b/tensorflow/compiler/xla/legacy_flags/buffer_assignment_flags.h new file mode 100644 index 0000000000..5f098c2663 --- /dev/null +++ b/tensorflow/compiler/xla/legacy_flags/buffer_assignment_flags.h @@ -0,0 +1,46 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_BUFFER_ASSIGNMENT_FLAGS_H_ +#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_BUFFER_ASSIGNMENT_FLAGS_H_ + +// Legacy flags for XLA's buffer_assignment module. + +#include <vector> + +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace xla { +namespace legacy_flags { + +// Append to *flag_list flag definitions associated with XLA's buffer_assignment +// module. +void AppendBufferAssignmentFlags(std::vector<tensorflow::Flag>* flag_list); + +// The values of flags associated with XLA's buffer_assignment module. +typedef struct { + bool xla_enable_buffer_reuse; // Enable reuse of buffers. +} BufferAssignmentFlags; + +// Return a pointer to the BufferAssignmentFlags struct; +// repeated calls return the same pointer. +// This should be called only after Flags::Parse() has returned. +BufferAssignmentFlags* GetBufferAssignmentFlags(); + +} // namespace legacy_flags +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_BUFFER_ASSIGNMENT_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.cc b/tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.cc new file mode 100644 index 0000000000..617a9b712e --- /dev/null +++ b/tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.cc @@ -0,0 +1,61 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Legacy flags for XLA's compiler_functor module. + +#include <mutex>  // NOLINT(build/c++11): only using std::call_once, not mutex. +#include <vector> + +#include "tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace xla { +namespace legacy_flags { + +// Pointers to the parsed value of the flags and flag descriptors, initialized +// via flags_init. +static CompilerFunctorFlags* flags; +static std::vector<tensorflow::Flag>* flag_list; +static std::once_flag flags_init; + +// Allocate *flags. Called via call_once(&flags_init,...). +static void AllocateFlags() { + flags = new CompilerFunctorFlags; + flag_list = new std::vector<tensorflow::Flag>({ + tensorflow::Flag("xla_debug_cpu_dump_ir", &flags->xla_debug_cpu_dump_ir, + "Dump IR before optimizations to the given path"), + }); + ParseFlagsFromEnv(*flag_list); +} + +// Append to *append_to flag definitions associated with XLA's compiler_functor +// module. 
+void AppendCompilerFunctorFlags(std::vector<tensorflow::Flag>* append_to) { + std::call_once(flags_init, &AllocateFlags); + append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); +} + +// Return a pointer to the CompilerFunctorFlags struct; +// repeated calls return the same pointer. +// This should be called only after Flags::Parse() has returned. +CompilerFunctorFlags* GetCompilerFunctorFlags() { + std::call_once(flags_init, &AllocateFlags); + return flags; +} + +} // namespace legacy_flags +} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.h b/tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.h new file mode 100644 index 0000000000..28b505ec5e --- /dev/null +++ b/tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.h @@ -0,0 +1,47 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_COMPILER_FUNCTOR_FLAGS_H_ +#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_COMPILER_FUNCTOR_FLAGS_H_ + +// Legacy flags for XLA's compiler_functor module. + +#include <vector> + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace xla { +namespace legacy_flags { + +// Append to *flag_list flag definitions associated with XLA's compiler_functor +// module. +void AppendCompilerFunctorFlags(std::vector<tensorflow::Flag>* flag_list); + +// The values of flags associated with XLA's compiler_functor module. +typedef struct { + string xla_debug_cpu_dump_ir; // Dump IR before optimizations to the given path +} CompilerFunctorFlags; + +// Return a pointer to the CompilerFunctorFlags struct; +// repeated calls return the same pointer. +// This should be called only after Flags::Parse() has returned. +CompilerFunctorFlags* GetCompilerFunctorFlags(); + +} // namespace legacy_flags +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_COMPILER_FUNCTOR_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/convolution_thunk_flags.cc b/tensorflow/compiler/xla/legacy_flags/convolution_thunk_flags.cc new file mode 100644 index 0000000000..fe5d19147f --- /dev/null +++ b/tensorflow/compiler/xla/legacy_flags/convolution_thunk_flags.cc @@ -0,0 +1,63 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// Legacy flags for XLA's convolution_thunk module. + +#include // NOLINT(build/c++11): only using std::call_once, not mutex. +#include + +#include "tensorflow/compiler/xla/legacy_flags/convolution_thunk_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace xla { +namespace legacy_flags { + +// Pointers to the parsed value of the flags and flag descriptors, initialized +// via flags_init. +static ConvolutionThunkFlags* flags; +static std::vector* flag_list; +static std::once_flag flags_init; + +// Allocate *flags. Called via call_once(&flags_init,...). +static void AllocateFlags() { + flags = new ConvolutionThunkFlags; + flags->xla_gpu_autotune_convolution_algorithm = true; + flag_list = new std::vector({ + tensorflow::Flag("xla_gpu_autotune_convolution_algorithm", + &flags->xla_gpu_autotune_convolution_algorithm, + "Auto-tune the algorithm used by convolution"), + }); + ParseFlagsFromEnv(*flag_list); +} + +// Append to *append_to flag definitions associated with XLA's convolution_thunk +// module. +void AppendConvolutionThunkFlags(std::vector* append_to) { + std::call_once(flags_init, &AllocateFlags); + append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); +} + +// Return a pointer to the ConvolutionThunkFlags struct; +// repeated calls return the same pointer. +// This should be called only after Flags::Parse() has returned. +ConvolutionThunkFlags* GetConvolutionThunkFlags() { + std::call_once(flags_init, &AllocateFlags); + return flags; +} + +} // namespace legacy_flags +} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/convolution_thunk_flags.h b/tensorflow/compiler/xla/legacy_flags/convolution_thunk_flags.h new file mode 100644 index 0000000000..53d6806a71 --- /dev/null +++ b/tensorflow/compiler/xla/legacy_flags/convolution_thunk_flags.h @@ -0,0 +1,47 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CONVOLUTION_THUNK_FLAGS_H_ +#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CONVOLUTION_THUNK_FLAGS_H_ + +// Legacy flags for XLA's convolution_thunk module. + +#include + +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace xla { +namespace legacy_flags { + +// Append to *flag_list flag definitions associated with XLA's convolution_thunk +// module. +void AppendConvolutionThunkFlags(std::vector* flag_list); + +// The values of flags associated with XLA's convolution_thunk module. 
+typedef struct { + // Auto-tune the algorithm used by convolution + bool xla_gpu_autotune_convolution_algorithm; +} ConvolutionThunkFlags; + +// Return a pointer to the ConvolutionThunkFlags struct; +// repeated calls return the same pointer. +// This should be called only after Flags::Parse() has returned. +ConvolutionThunkFlags* GetConvolutionThunkFlags(); + +} // namespace legacy_flags +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CONVOLUTION_THUNK_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.cc b/tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.cc new file mode 100644 index 0000000000..f8ae25552d --- /dev/null +++ b/tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.cc @@ -0,0 +1,76 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Legacy flags for XLA's cpu_compiler module. + +#include // NOLINT(build/c++11): only using std::call_once, not mutex. +#include + +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace xla { +namespace legacy_flags { + +// Pointers to the parsed value of the flags and flag descriptors, initialized +// via flags_init. +static CpuCompilerFlags* flags; +static std::vector* flag_list; +static std::once_flag flags_init; + +// Allocate *flags. Called via call_once(&flags_init,...). +static void AllocateFlags() { + flags = new CpuCompilerFlags; + flags->xla_cpu_llvm_opt_level = 2; + flags->xla_cpu_llvm_cl_opts = ""; + flags->xla_cpu_embed_ir = false; + flags->xla_cpu_parallel = false; + flag_list = new std::vector({ + tensorflow::Flag( + "xla_cpu_llvm_opt_level", &flags->xla_cpu_llvm_opt_level, + "The LLVM optimization level for the CPU XLA backend. " + "Valid range is from 0 to 3 where 0 means no optimizations."), + tensorflow::Flag( + "xla_cpu_llvm_cl_opts", &flags->xla_cpu_llvm_cl_opts, + "Comma-separated list of command line options to pass to LLVM."), + tensorflow::Flag( + "xla_cpu_embed_ir", &flags->xla_cpu_embed_ir, + "Embed the LLVM IR module string in the resultant CpuExecutable."), + tensorflow::Flag("xla_cpu_parallel", &flags->xla_cpu_parallel, + "Use the multi-threaded CPU backend."), + }); + ParseFlagsFromEnv(*flag_list); +} + +// Append to *append_to flag definitions associated with XLA's cpu_compiler +// module. +void AppendCpuCompilerFlags(std::vector* append_to) { + std::call_once(flags_init, &AllocateFlags); + append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); +} + +// Return a pointer to the CpuCompilerFlags struct; +// repeated calls return the same pointer. +// This should be called only after Flags::Parse() has returned. 
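+// A usage sketch (illustrative only; assumes a tool that wires these flag
+// definitions into its own command-line parsing via tensorflow::Flags::Parse):
+//
+//   std::vector<tensorflow::Flag> flag_list;
+//   xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+//   // ... tensorflow::Flags::Parse(&argc, argv, flag_list) in main() ...
+//   const CpuCompilerFlags* cpu_flags =
+//       xla::legacy_flags::GetCpuCompilerFlags();
+//   if (cpu_flags->xla_cpu_parallel) {
+//     // select the multi-threaded CPU backend
+//   }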
+CpuCompilerFlags* GetCpuCompilerFlags() { + std::call_once(flags_init, &AllocateFlags); + return flags; +} + +} // namespace legacy_flags +} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h b/tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h new file mode 100644 index 0000000000..16a7b68711 --- /dev/null +++ b/tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h @@ -0,0 +1,54 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CPU_COMPILER_FLAGS_H_ +#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CPU_COMPILER_FLAGS_H_ + +// Legacy flags for the XLA's cpu_compiler module. + +#include + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace xla { +namespace legacy_flags { + +// Append to *flag_list flag definitions associated with XLA's cpu_compiler +// module. +void AppendCpuCompilerFlags(std::vector* flag_list); + +// The values of flags associated with XLA's cpu_compiler module. +typedef struct { + // The LLVM optimization level for the CPU XLA backend. + // Valid range is from 0 to 3 where 0 means no optimizations. + int32 xla_cpu_llvm_opt_level; + string xla_cpu_llvm_cl_opts; // Comma-separated list of command line options + // to pass to LLVM. + bool xla_cpu_embed_ir; // Embed the LLVM IR module string in the resultant + // CpuExecutable + bool xla_cpu_parallel; // Use the multi-threaded CPU backend. +} CpuCompilerFlags; + +// Return a pointer to the CpuCompilerFlags struct; +// repeated calls return the same pointer. +// This should be called only after Flags::Parse() has returned. +CpuCompilerFlags* GetCpuCompilerFlags(); + +} // namespace legacy_flags +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CPU_COMPILER_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.cc b/tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.cc new file mode 100644 index 0000000000..d7817c5d54 --- /dev/null +++ b/tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.cc @@ -0,0 +1,71 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Legacy flags for XLA's cpu_runtime module. 
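+//
+// These flags are also picked up from the environment by ParseFlagsFromEnv
+// (see parse_flags_from_env.h), so a rough way to flip them without changing
+// the command line is, for example (hypothetical binary name):
+//
+//   TF_XLA_FLAGS="--xla_cpu_use_eigen=true --xla_cpu_multi_thread_eigen=false" ./my_xla_tool
+//
+// Note that xla_cpu_multi_thread_eigen only takes effect when
+// xla_cpu_use_eigen is true.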
+ +#include // NOLINT(build/c++11): only using std::call_once, not mutex. +#include + +#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace xla { +namespace legacy_flags { + +// Pointers to the parsed value of the flags and flag descriptors, initialized +// via flags_init. +static CpuRuntimeFlags* flags; +static std::vector* flag_list; +static std::once_flag flags_init; + +// Allocate *flags. Called via call_once(&flags_init,...). +static void AllocateFlags() { + flags = new CpuRuntimeFlags; + flags->xla_cpu_use_eigen = true; + flags->xla_cpu_multi_thread_eigen = true; + flag_list = new std::vector({ + tensorflow::Flag( + "xla_cpu_use_eigen", &flags->xla_cpu_use_eigen, + "Use Eigen for matrix multiply on the CPU platform. This " + "is a useful hack for performance comparisons against " + "XLA's implementation."), + tensorflow::Flag( + "xla_cpu_multi_thread_eigen", &flags->xla_cpu_multi_thread_eigen, + "When generating calls to Eigen for matmul and conv, should " + "single or multi-threaded eigen be used? " + "Only used when --xla_cpu_use_eigen is true."), + }); + ParseFlagsFromEnv(*flag_list); +} + +// Append to *append_to flag definitions associated with XLA's cpu_runtime +// module. +void AppendCpuRuntimeFlags(std::vector* append_to) { + std::call_once(flags_init, &AllocateFlags); + append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); +} + +// Return a pointer to the CpuRuntimeFlags struct; +// repeated calls return the same pointer. +// This should be called only after Flags::Parse() has returned. +CpuRuntimeFlags* GetCpuRuntimeFlags() { + std::call_once(flags_init, &AllocateFlags); + return flags; +} + +} // namespace legacy_flags +} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h b/tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h new file mode 100644 index 0000000000..e3ff30da36 --- /dev/null +++ b/tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h @@ -0,0 +1,51 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CPU_RUNTIME_FLAGS_H_ +#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CPU_RUNTIME_FLAGS_H_ + +// Legacy flags for the XLA's cpu_runtime module. + +#include + +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace xla { +namespace legacy_flags { + +// Append to *flag_list flag definitions associated with XLA's cpu_runtime +// module. +void AppendCpuRuntimeFlags(std::vector* flag_list); + +// The values of flags associated with XLA's cpu_runtime module. +typedef struct { + // Use Eigen for matrix multiply on the CPU platform. 
This is a useful hack + // for performance comparisons against XLA's implementation. + bool xla_cpu_use_eigen; + // When generating calls to Eigen for matmul and conv, should single or + // multi-threaded eigen be used? Only used when --xla_cpu_use_eigen is true. + bool xla_cpu_multi_thread_eigen; +} CpuRuntimeFlags; + +// Return a pointer to the CpuRuntimeFlags struct; +// repeated calls return the same pointer. +// This should be called only after Flags::Parse() has returned. +CpuRuntimeFlags* GetCpuRuntimeFlags(); + +} // namespace legacy_flags +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CPU_RUNTIME_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.cc b/tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.cc new file mode 100644 index 0000000000..c355b1ed9b --- /dev/null +++ b/tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.cc @@ -0,0 +1,91 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Legacy flags for XLA's gpu_backend_lib module. + +#include // NOLINT(build/c++11): only using std::call_once, not mutex. +#include + +#include "tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace xla { +namespace legacy_flags { + +// Pointers to the parsed value of the flags and flag descriptors, initialized +// via flags_init. +static GpuBackendLibFlags* flags; +static std::vector* flag_list; +static std::once_flag flags_init; + +// Allocate *flags. Called via call_once(&flags_init,...). +static void AllocateFlags() { + flags = new GpuBackendLibFlags; + flags->dump_temp_products_to = ""; + flags->ftz = false; + flags->fma = true; + flags->gpu_architecture = "compute_35"; + flags->verbose_ptx_asm = false; + flags->kernel = ""; + flags->llvm_dump_passes = false; + flags->llvm_cl_opts = ""; + flags->dump_ir_before_passes = false; + flags->opt_level = 3; + flag_list = new std::vector({ + tensorflow::Flag("dump_temp_products_to", &flags->dump_temp_products_to, + "dump temporary compilation products to this directory. " + "If empty, no dump is produced"), + tensorflow::Flag("ftz", &flags->ftz, "flush to zero semantics"), + tensorflow::Flag("fma", &flags->fma, "use FMA synthesis"), + tensorflow::Flag("gpu_architecture", &flags->gpu_architecture, + "GPU architecture"), + tensorflow::Flag("verbose_ptx_asm", &flags->verbose_ptx_asm, + "emit PTX assembly with extra comments"), + tensorflow::Flag("kernel", &flags->kernel, + "only emit the IR and PTX for this kernel"), + tensorflow::Flag("llvm_dump_passes", &flags->llvm_dump_passes, + "dump the passes LLVM runs to stderr"), + tensorflow::Flag( + "llvm_cl_opts", &flags->llvm_cl_opts, + "comma-separated list of command line options to pass to " + "LLVM. 
For example, --llvm_cl_opts=--print-before=loop-unroll"), + tensorflow::Flag("dump_ir_before_passes", &flags->dump_ir_before_passes, + "dump the IR before each optimization pass in " + "sequentially-named files."), + tensorflow::Flag("opt_level", &flags->opt_level, + "optimization level (defaults to 3)"), + }); + ParseFlagsFromEnv(*flag_list); +} + +// Append to *append_to flag definitions associated with XLA's gpu_backend_lib +// module. +void AppendGpuBackendLibFlags(std::vector<tensorflow::Flag>* append_to) { + std::call_once(flags_init, &AllocateFlags); + append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); +} + +// Return a pointer to the GpuBackendLibFlags struct; +// repeated calls return the same pointer. +// This should be called only after Flags::Parse() has returned. +GpuBackendLibFlags* GetGpuBackendLibFlags() { + std::call_once(flags_init, &AllocateFlags); + return flags; +} + +} // namespace legacy_flags +} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.h b/tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.h new file mode 100644 index 0000000000..fbb8863454 --- /dev/null +++ b/tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.h @@ -0,0 +1,56 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_GPU_BACKEND_LIB_FLAGS_H_ +#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_GPU_BACKEND_LIB_FLAGS_H_ + +// Legacy flags for XLA's gpu_backend_lib module. + +#include <vector> + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace xla { +namespace legacy_flags { + +// Append to *flag_list flag definitions associated with XLA's gpu_backend_lib +// module. +void AppendGpuBackendLibFlags(std::vector<tensorflow::Flag>* flag_list); + +// The values of flags associated with XLA's gpu_backend_lib module. +typedef struct { + string dump_temp_products_to; // temporary compilation products dir + bool ftz; // flush to zero semantics + bool fma; // use FMA synthesis + string gpu_architecture; // GPU architecture + bool verbose_ptx_asm; // emit PTX assembly with extra comments + string kernel; // only emit the IR and PTX for this kernel + bool llvm_dump_passes; // dump the passes LLVM runs to stderr + string llvm_cl_opts; // comma-separated list of LLVM options + bool dump_ir_before_passes; // dump IR before each pass + int32 opt_level; // optimization level +} GpuBackendLibFlags; + +// Return a pointer to the GpuBackendLibFlags struct; +// repeated calls return the same pointer. +// This should be called only after Flags::Parse() has returned. 
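+// For instance, a consumer inside the GPU backend might do something along
+// these lines (sketch only, not code from this change):
+//
+//   const GpuBackendLibFlags* gpu_flags = GetGpuBackendLibFlags();
+//   if (!gpu_flags->dump_temp_products_to.empty()) {
+//     // write IR/PTX dumps under gpu_flags->dump_temp_products_to
+//   }
+//   int opt_level = gpu_flags->opt_level;  // defaults to 3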
+GpuBackendLibFlags* GetGpuBackendLibFlags(); + +} // namespace legacy_flags +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_GPU_BACKEND_LIB_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.cc b/tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.cc new file mode 100644 index 0000000000..e79d363509 --- /dev/null +++ b/tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.cc @@ -0,0 +1,73 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Legacy flags for XLA's gpu_compiler module. + +#include // NOLINT(build/c++11): only using std::call_once, not mutex. +#include + +#include "tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace xla { +namespace legacy_flags { + +// Pointers to the parsed value of the flags and flag descriptors, initialized +// via flags_init. +static GpuCompilerFlags* flags; +static std::vector* flag_list; +static std::once_flag flags_init; + +// Allocate *flags. Called via call_once(&flags_init,...). +static void AllocateFlags() { + flags = new GpuCompilerFlags; + flags->xla_gpu_embed_ir = false; + flags->xla_cuda_data_dir = "./cuda_sdk_lib"; + flags->xla_ptxas_path = "/usr/local/cuda/bin/ptxas"; + flag_list = new std::vector({ + tensorflow::Flag( + "xla_gpu_embed_ir", &flags->xla_gpu_embed_ir, + "Embed the LLVM IR module string in the resultant GpuExecutable."), + tensorflow::Flag( + "xla_cuda_data_dir", &flags->xla_cuda_data_dir, + "If non-empty, specifies a local directory containing ptxas and " + "nvvm libdevice files. Otherwise, by default, we use those from " + "runfile directories."), + tensorflow::Flag("xla_ptxas_path", &flags->xla_ptxas_path, + "The path to ptxas. Required to log stats of the ptx."), + }); + ParseFlagsFromEnv(*flag_list); +} + +// Append to *append_to flag definitions associated with XLA's gpu_compiler +// module. +void AppendGpuCompilerFlags(std::vector* append_to) { + std::call_once(flags_init, &AllocateFlags); + append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); +} + +// Return a pointer to the GpuCompilerFlags struct; +// repeated calls return the same pointer. +// This should be called only after Flags::Parse() has returned. +GpuCompilerFlags* GetGpuCompilerFlags() { + std::call_once(flags_init, &AllocateFlags); + return flags; +} + +} // namespace legacy_flags +} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.h b/tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.h new file mode 100644 index 0000000000..04ddedab73 --- /dev/null +++ b/tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.h @@ -0,0 +1,54 @@ +/* Copyright 2017 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_GPU_COMPILER_FLAGS_H_ +#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_GPU_COMPILER_FLAGS_H_ + +// Legacy flags for XLA's gpu_compiler module. + +#include + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace xla { +namespace legacy_flags { + +// Append to *flag_list flag definitions associated with XLA's gpu_compiler +// module. +void AppendGpuCompilerFlags(std::vector* flag_list); + +// The values of flags associated with XLA's gpu_compiler module. +typedef struct { + bool xla_gpu_embed_ir; // Embed the LLVM IR module string in the resultant + // GpuExecutable. + string xla_cuda_data_dir; // If non-empty, specifies a local directory + // containing ptxas and nvvm libdevice files. + // Otherwise, by default, we use those from runfile + // directories. + string xla_ptxas_path; // The path to ptxas. Required to log stats of + // the ptx. +} GpuCompilerFlags; + +// Return a pointer to the GpuCompilerFlags struct; +// repeated calls return the same pointer. +// This should be called only after Flags::Parse() has returned. +GpuCompilerFlags* GetGpuCompilerFlags(); + +} // namespace legacy_flags +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_GPU_COMPILER_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.cc b/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.cc new file mode 100644 index 0000000000..8822f6f610 --- /dev/null +++ b/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.cc @@ -0,0 +1,63 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Legacy flags for XLA's hlo_graph_dumper module. + +#include // NOLINT(build/c++11): only using std::call_once, not mutex. +#include + +#include "tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace xla { +namespace legacy_flags { + +// Pointers to the parsed value of the flags and flag descriptors, initialized +// via flags_init. 
+static HloGraphDumperFlags* flags; +static std::vector* flag_list; +static std::once_flag flags_init; + +// Allocate *flags. Called via call_once(&flags_init,...). +static void AllocateFlags() { + flags = new HloGraphDumperFlags; + flags->xla_hlo_dump_graph_path = "/tmp/"; + flag_list = new std::vector({ + tensorflow::Flag("xla_hlo_dump_graph_path", + &flags->xla_hlo_dump_graph_path, + "Path to write dumped HLO graphs to"), + }); + ParseFlagsFromEnv(*flag_list); +} + +// Append to *append_to flag definitions associated with XLA's hlo_graph_dumper +// module. +void AppendHloGraphDumperFlags(std::vector* append_to) { + std::call_once(flags_init, &AllocateFlags); + append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); +} + +// Return a pointer to the HloGraphDumperFlags struct; +// repeated calls return the same pointer. +// This should be called only after Flags::Parse() has returned. +HloGraphDumperFlags* GetHloGraphDumperFlags() { + std::call_once(flags_init, &AllocateFlags); + return flags; +} + +} // namespace legacy_flags +} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h b/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h new file mode 100644 index 0000000000..b6dfced87c --- /dev/null +++ b/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h @@ -0,0 +1,47 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_HLO_GRAPH_DUMPER_FLAGS_H_ +#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_HLO_GRAPH_DUMPER_FLAGS_H_ + +// Legacy flags for XLA's hlo_graph_dumper module. + +#include + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace xla { +namespace legacy_flags { + +// Append to *flag_list flag definitions associated with XLA's hlo_graph_dumper +// module. +void AppendHloGraphDumperFlags(std::vector* flag_list); + +// The values of flags associated with XLA's hlo_graph_dumper module. +typedef struct { + string xla_hlo_dump_graph_path; // Path to write dumped HLO graphs to +} HloGraphDumperFlags; + +// Return a pointer to the HloGraphDumperFlags struct; +// repeated calls return the same pointer. +// This should be called only after Flags::Parse() has returned. +HloGraphDumperFlags* GetHloGraphDumperFlags(); + +} // namespace legacy_flags +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_HLO_GRAPH_DUMPER_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.cc b/tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.cc new file mode 100644 index 0000000000..edc04d51a7 --- /dev/null +++ b/tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.cc @@ -0,0 +1,62 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Legacy flags for XLA's hlo_pass_pipeline module. + +#include // NOLINT(build/c++11): only using std::call_once, not mutex. +#include + +#include "tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace xla { +namespace legacy_flags { + +// Pointers to the parsed value of the flags and flag descriptors, initialized +// via flags_init. +static HloPassPipelineFlags* flags; +static std::vector* flag_list; +static std::once_flag flags_init; + +// Allocate *flags. Called via call_once(&flags_init,...). +static void AllocateFlags() { + flags = new HloPassPipelineFlags; + flags->xla_disable_hlo_passes = ""; + flag_list = new std::vector({ + tensorflow::Flag("xla_disable_hlo_passes", &flags->xla_disable_hlo_passes, + "Comma-separated list of HLO passes to disable."), + }); + ParseFlagsFromEnv(*flag_list); +} + +// Append to *append_to flag definitions associated with XLA's hlo_pass_pipeline +// module. +void AppendHloPassPipelineFlags(std::vector* append_to) { + std::call_once(flags_init, &AllocateFlags); + append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); +} + +// Return a pointer to the HloPassPipelineFlags struct; +// repeated calls return the same pointer. +// This should be called only after Flags::Parse() has returned. +HloPassPipelineFlags* GetHloPassPipelineFlags() { + std::call_once(flags_init, &AllocateFlags); + return flags; +} + +} // namespace legacy_flags +} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.h b/tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.h new file mode 100644 index 0000000000..520759bbf0 --- /dev/null +++ b/tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.h @@ -0,0 +1,48 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_HLO_PASS_PIPELINE_FLAGS_H_ +#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_HLO_PASS_PIPELINE_FLAGS_H_ + +// Legacy flags for XLA's hlo_pass_pipeline module. 
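+//
+// The xla_disable_hlo_passes value is a comma-separated list of pass names.
+// A pipeline could honor it roughly as follows (illustrative sketch; the
+// actual pipeline implementation is not shown here):
+//
+//   std::vector<string> disabled = tensorflow::str_util::Split(
+//       GetHloPassPipelineFlags()->xla_disable_hlo_passes, ',');
+//   std::set<string> disabled_set(disabled.begin(), disabled.end());
+//   if (disabled_set.count(pass_name) == 0) {
+//     // run the pass
+//   }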
+ +#include + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace xla { +namespace legacy_flags { + +// Append to *flag_list flag definitions associated with XLA's hlo_pass_pipeline +// module. +void AppendHloPassPipelineFlags(std::vector* flag_list); + +// The values of flags associated with XLA's hlo_pass_pipeline module. +typedef struct { + // Comma-separated list of HLO passes to disable. + string xla_disable_hlo_passes; +} HloPassPipelineFlags; + +// Return a pointer to the HloPassPipelineFlags struct; +// repeated calls return the same pointer. +// This should be called only after Flags::Parse() has returned. +HloPassPipelineFlags* GetHloPassPipelineFlags(); + +} // namespace legacy_flags +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_HLO_PASS_PIPELINE_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/hlo_test_base_flags.cc b/tensorflow/compiler/xla/legacy_flags/hlo_test_base_flags.cc new file mode 100644 index 0000000000..c7893c1385 --- /dev/null +++ b/tensorflow/compiler/xla/legacy_flags/hlo_test_base_flags.cc @@ -0,0 +1,63 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Legacy flags for XLA's hlo_test_base module. + +#include // NOLINT(build/c++11): only using std::call_once, not mutex. +#include + +#include "tensorflow/compiler/xla/legacy_flags/hlo_test_base_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace xla { +namespace legacy_flags { + +// Pointers to the parsed value of the flags and flag descriptors, initialized +// via flags_init. +static HloTestBaseFlags* flags; +static std::vector* flag_list; +static std::once_flag flags_init; + +// Allocate *flags. Called via call_once(&flags_init,...). +static void AllocateFlags() { + flags = new HloTestBaseFlags; + flags->xla_hlo_test_generate_hlo_graph = false; + flag_list = new std::vector({ + tensorflow::Flag("xla_hlo_test_generate_hlo_graph", + &flags->xla_hlo_test_generate_hlo_graph, + "Generate graph output of HLO instructions"), + }); + ParseFlagsFromEnv(*flag_list); +} + +// Append to *append_to flag definitions associated with XLA's hlo_test_base +// module. +void AppendHloTestBaseFlags(std::vector* append_to) { + std::call_once(flags_init, &AllocateFlags); + append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); +} + +// Return a pointer to the HloTestBaseFlags struct; +// repeated calls return the same pointer. +// This should be called only after Flags::Parse() has returned. 
+HloTestBaseFlags* GetHloTestBaseFlags() { + std::call_once(flags_init, &AllocateFlags); + return flags; +} + +} // namespace legacy_flags +} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/hlo_test_base_flags.h b/tensorflow/compiler/xla/legacy_flags/hlo_test_base_flags.h new file mode 100644 index 0000000000..23b808cecb --- /dev/null +++ b/tensorflow/compiler/xla/legacy_flags/hlo_test_base_flags.h @@ -0,0 +1,47 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_HLO_TEST_BASE_FLAGS_H_ +#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_HLO_TEST_BASE_FLAGS_H_ + +// Legacy flags for XLA's hlo_test_base module. + +#include + +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace xla { +namespace legacy_flags { + +// Append to *flag_list flag definitions associated with XLA's hlo_test_base +// module. +void AppendHloTestBaseFlags(std::vector* flag_list); + +// The values of flags associated with XLA's hlo_test_base module. +typedef struct { + bool xla_hlo_test_generate_hlo_graph; // Generate graph output of HLO + // instructions +} HloTestBaseFlags; + +// Return a pointer to the HloTestBaseFlags struct; +// repeated calls return the same pointer. +// This should be called only after Flags::Parse() has returned. +HloTestBaseFlags* GetHloTestBaseFlags(); + +} // namespace legacy_flags +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_HLO_TEST_BASE_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/layout_util_flags.cc b/tensorflow/compiler/xla/legacy_flags/layout_util_flags.cc new file mode 100644 index 0000000000..4242b501d4 --- /dev/null +++ b/tensorflow/compiler/xla/legacy_flags/layout_util_flags.cc @@ -0,0 +1,107 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Legacy flags for XLA's layout_util module. + +#include // NOLINT(build/c++11): only using std::call_once, not mutex. 
+#include + +#include "tensorflow/compiler/xla/legacy_flags/layout_util_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace xla { +namespace legacy_flags { + +// Pointers to the string value of the xla_default_layout flag and the flag +// descriptor, initialized via raw_flags_init. +static string* raw_flag; +static std::vector* flag_list; +static std::once_flag raw_flags_init; + +// Allocate *raw_flag. Called via call_once(&raw_flags_init,...). +static void AllocateRawFlag() { + raw_flag = new string; + flag_list = new std::vector({ + tensorflow::Flag( + "xla_default_layout", raw_flag, + "Default layout for Shapes in XLA. Valid values are: " + "'minor2major', 'major2minor', 'random', 'random:'. " + "For debugging purposes. If no seed (or 0) is given, a seed from " + "random_device is used."), + }); + ParseFlagsFromEnv(*flag_list); +} + +// Parse text into *layout. +static bool ParseDefaultLayout(const string& text, DefaultLayout* layout) { + bool result = true; + std::vector field = tensorflow::str_util::Split(text, ':'); + if (field.size() > 0) { + if (field[0] == "random") { + layout->dimension_order = DefaultLayout::DimensionOrder::kRandom; + if (field.size() > 1) { + uint64 seed = 0; + result = tensorflow::strings::safe_strtou64(field[1], &seed); + layout->seed = seed; + } + } else if (field[0] == "minor2major") { + layout->dimension_order = DefaultLayout::DimensionOrder::kMinorToMajor; + } else if (field[0] == "major2minor") { + layout->dimension_order = DefaultLayout::DimensionOrder::kMajorToMinor; + } else { + result = false; + } + } + return result; +} + +// Pointer to the parsed value of the flags, initialized via flags_init. +static LayoutUtilFlags* flags; +static std::once_flag flags_init; + +// Allocate *flags. Called via call_once(&flags_init,...). +static void AllocateFlags() { + std::call_once(raw_flags_init, &AllocateRawFlag); + flags = new LayoutUtilFlags; + flags->xla_default_layout.dimension_order = + DefaultLayout::DimensionOrder::kMajorToMinor; + flags->xla_default_layout.seed = 0; + if (!ParseDefaultLayout(*raw_flag, &flags->xla_default_layout)) { + flags = nullptr; + } +} + +// Append to *append_to the flag definitions associated with XLA's layout_util +// module. +void AppendLayoutUtilFlags(std::vector* append_to) { + std::call_once(raw_flags_init, &AllocateRawFlag); + append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); +} + +// Return a pointer to the LayoutUtilFlags struct; +// repeated calls return the same pointer. +// This should be called only after Flags::Parse() has returned. +LayoutUtilFlags* GetLayoutUtilFlags() { + std::call_once(flags_init, &AllocateFlags); + return flags; +} + +} // namespace legacy_flags +} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/layout_util_flags.h b/tensorflow/compiler/xla/legacy_flags/layout_util_flags.h new file mode 100644 index 0000000000..177f428b73 --- /dev/null +++ b/tensorflow/compiler/xla/legacy_flags/layout_util_flags.h @@ -0,0 +1,62 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_LAYOUT_UTIL_FLAGS_H_ +#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_LAYOUT_UTIL_FLAGS_H_ + +// Legacy flags for the XLA's layout_util module. + +#include + +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace xla { +namespace legacy_flags { + +// The default layout for all newly created shapes. Specified by the flag +// --xla_default_layout. +struct DefaultLayout { + enum class DimensionOrder { + kRandom, + kMinorToMajor, + kMajorToMinor, + }; + + DimensionOrder dimension_order; + size_t seed; +}; + +// Append to *flag_list the flag definitions associated with XLA's layout_util +// module. +void AppendLayoutUtilFlags(std::vector* flag_list); + +// The values of flags associated with XLA's layout_util module. +typedef struct { + // Default layout for Shapes in XLA. Valid values are: 'minor2major', + // 'major2minor', 'random', 'random:'. For debugging purposes. If no + // seed (or 0) is given, a seed from random_device is used. + DefaultLayout xla_default_layout; +} LayoutUtilFlags; + +// Return a pointer to the LayoutFlags struct; +// repeated calls return the same pointer. +// This should be called only after Flags::Parse() has returned. +LayoutUtilFlags* GetLayoutUtilFlags(); + +} // namespace legacy_flags +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_LAYOUT_UTIL_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/llvm_backend_flags.cc b/tensorflow/compiler/xla/legacy_flags/llvm_backend_flags.cc new file mode 100644 index 0000000000..c8a71b284f --- /dev/null +++ b/tensorflow/compiler/xla/legacy_flags/llvm_backend_flags.cc @@ -0,0 +1,67 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Legacy flags associated with XLA's use of LLVM for code generation. + +#include // NOLINT(build/c++11): only using std::call_once, not mutex. +#include + +#include "tensorflow/compiler/xla/legacy_flags/llvm_backend_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace xla { +namespace legacy_flags { + +// Pointers to the parsed value of the flags and flag descriptors, initialized +// via flags_init. +static LlvmBackendFlags* flags; +static std::vector* flag_list; +static std::once_flag flags_init; + +// Allocate *flags. 
Called via call_once(&flags_init,...). +static void AllocateFlags() { + flags = new LlvmBackendFlags; + flags->xla_fast_math = true; + flags->xla_precision_losing_optimizations = true; + flag_list = new std::vector({ + tensorflow::Flag( + "xla_precision_losing_optimizations", + &flags->xla_precision_losing_optimizations, + "Allows llvm to make transformations that reduce the precision of " + "floating-point computations. This is equivalent to clang's " + "-funsafe-math-optimizations flag."), + tensorflow::Flag( + "xla_fast_math", &flags->xla_fast_math, + "Allows llvm to make all manner of unsafe floating-point " + "optimizations, including assuming that NaN and Inf don't appear. " + "This is equivalent to clang's -ffast-math flag."), + }); + ParseFlagsFromEnv(*flag_list); +} + +void AppendLlvmBackendFlags(std::vector* append_to) { + std::call_once(flags_init, &AllocateFlags); + append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); +} + +LlvmBackendFlags* GetLlvmBackendFlags() { + std::call_once(flags_init, &AllocateFlags); + return flags; +} + +} // namespace legacy_flags +} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/llvm_backend_flags.h b/tensorflow/compiler/xla/legacy_flags/llvm_backend_flags.h new file mode 100644 index 0000000000..e8c0489285 --- /dev/null +++ b/tensorflow/compiler/xla/legacy_flags/llvm_backend_flags.h @@ -0,0 +1,58 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_LLVM_BACKEND_FLAGS_H_ +#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_LLVM_BACKEND_FLAGS_H_ + +#include + +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace xla { +namespace legacy_flags { + +// Append to *flag_list flag definitions associated with XLA's use of LLVM for +// code generation. +void AppendLlvmBackendFlags(std::vector* flag_list); + +// The values of flags associated with XLA's use of LLVM for code generation. +typedef struct { + // Allows llvm to make transformations that reduce the precision of + // floating-point computations, but it *does not* allow it to disregard signed + // zero or assume that NaN and Inf never appear. + // + // Controls the "UnsafeFPMath" LLVM target option and + // llvm::FastMathFlags::allowReciprocal. This is equivalent to clang's + // -funsafe-math-optimizations flag. + bool xla_precision_losing_optimizations; + + // Unleashes the full power of LLVM's unsafe floating-point optimizations. + // Everything is fair game, including disregarding signed zero and assuming + // that NaN and Inf never appear. + // + // This implies xla_precision_losing_optimizations, and is equivalent to + // clang's -ffast-math flag. + bool xla_fast_math; +} LlvmBackendFlags; + +// Return a pointer to the LlvmBackendFlags struct. Repeated calls return the +// same pointer. 
This should be called only after Flags::Parse() has returned. +LlvmBackendFlags* GetLlvmBackendFlags(); + +} // namespace legacy_flags +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_LLVM_BACKEND_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/llvm_util_flags.cc b/tensorflow/compiler/xla/legacy_flags/llvm_util_flags.cc new file mode 100644 index 0000000000..3c53729a67 --- /dev/null +++ b/tensorflow/compiler/xla/legacy_flags/llvm_util_flags.cc @@ -0,0 +1,63 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Legacy flags for XLA's llvm_util module. + +#include // NOLINT(build/c++11): only using std::call_once, not mutex. +#include + +#include "tensorflow/compiler/xla/legacy_flags/llvm_util_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace xla { +namespace legacy_flags { + +// Pointers to the parsed value of the flags and flag descriptors, initialized +// via flags_init. +static LlvmUtilFlags* flags; +static std::vector* flag_list; +static std::once_flag flags_init; + +// Allocate *flags. Called via call_once(&flags_init,...). +static void AllocateFlags() { + flags = new LlvmUtilFlags; + flags->xla_emit_tbaa = true; + flag_list = new std::vector({ + tensorflow::Flag("xla_emit_tbaa", &flags->xla_emit_tbaa, + "Perform type-based alias analysis optimizations for " + "LLVM-based backends."), + }); + ParseFlagsFromEnv(*flag_list); +} + +// Append to *append_to flag definitions associated with XLA's llvm_util +// module. +void AppendLlvmUtilFlags(std::vector* append_to) { + std::call_once(flags_init, &AllocateFlags); + append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); +} + +// Return a pointer to the LlvmUtilFlags struct; +// repeated calls return the same pointer. +// This should be called only after Flags::Parse() has returned. +LlvmUtilFlags* GetLlvmUtilFlags() { + std::call_once(flags_init, &AllocateFlags); + return flags; +} + +} // namespace legacy_flags +} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/llvm_util_flags.h b/tensorflow/compiler/xla/legacy_flags/llvm_util_flags.h new file mode 100644 index 0000000000..98da26b4b8 --- /dev/null +++ b/tensorflow/compiler/xla/legacy_flags/llvm_util_flags.h @@ -0,0 +1,46 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_LLVM_UTIL_FLAGS_H_ +#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_LLVM_UTIL_FLAGS_H_ + +// Legacy flags for XLA's llvm_util module. + +#include + +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace xla { +namespace legacy_flags { + +// Append to *flag_list flag definitions associated with XLA's llvm_util module. +void AppendLlvmUtilFlags(std::vector* flag_list); + +// The values of flags associated with XLA's llvm_util module. +typedef struct { + bool xla_emit_tbaa; // Perform type-based alias analysis optimizations for + // LLVM-based backends. +} LlvmUtilFlags; + +// Return a pointer to the LlvmUtilFlags struct; +// repeated calls return the same pointer. +// This should be called only after Flags::Parse() has returned. +LlvmUtilFlags* GetLlvmUtilFlags(); + +} // namespace legacy_flags +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_LLVM_UTIL_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.cc b/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.cc new file mode 100644 index 0000000000..2a4e49b05a --- /dev/null +++ b/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.cc @@ -0,0 +1,206 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This module exports ParseFlagsFromEnv(), which allows other modules to parse +// flags from an environtment variable, or a file named by the environment +// variable. + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace xla { +namespace legacy_flags { + +static const char kEnvVar[] = "TF_XLA_FLAGS"; // environment variable queried +static const char kWS[] = " \t\r\n"; // whitespace + +// The following struct represents an argv[]-style array, parsed +// from data gleaned from the environment. +// +// As usual, an anonymous namespace is advisable to avoid +// constructor/destructor collisions with other "private" types +// in the same named namespace. +namespace { +struct EnvArgv { + EnvArgv() : initialized(false), argc(0) {} + bool initialized; // whether the other fields have been set. + int argc; // elements used in argv[] + std::vector argv; // flag arguments parsed from environment string. 
+ std::vector argv_save; // saved values from argv[] to avoid leaks +}; +} // anonymous namespace + +// Append the string s0[0, .., s0len-1] concatenated with s1[0, .., s1len-1] as +// a newly allocated nul-terminated string to the array *a. If s0==nullptr, a +// nullptr is appended without increasing a->argc. +static void AppendToEnvArgv(const char* s0, size_t s0len, const char* s1, + size_t s1len, EnvArgv* a) { + if (s0 == nullptr) { + a->argv.push_back(nullptr); + a->argv_save.push_back(nullptr); + } else { + string s = string(s0, s0len) + string(s1, s1len); + char* str = strdup(s.c_str()); + a->argv.push_back(str); + a->argv_save.push_back(str); + a->argc++; + } +} + +// Like s.find_first_of(x, pos), but return s.size() when find_first_of() would +// return string::npos. This avoids if-statements elsewhere. +static size_t FindFirstOf(const string& s, const char* x, size_t pos) { + size_t result = s.find_first_of(x, pos); + return result == string::npos ? s.size() : result; +} + +// Like s.find_first_not_of(x, pos), but return s.size() when +// find_first_not_of() would return string::npos. This avoids if-statements +// elsewhere. +static size_t FindFirstNotOf(const string& s, const char* x, size_t pos) { + size_t result = s.find_first_not_of(x, pos); + return result == string::npos ? s.size() : result; +} + +// Given a string containing flags, parse them into the XLA command line flags. +// The parse is best effort, and gives up on the first syntax error. +static void ParseArgvFromString(const string& flag_str, EnvArgv* a) { + size_t b = FindFirstNotOf(flag_str, kWS, 0); + while (b != flag_str.size() && flag_str[b] == '-') { + // b is the index of the start of a flag. + // Set e to the index just past the end of the flag. + size_t e = b; + while (e != flag_str.size() && isascii(flag_str[e]) && + (strchr("-_", flag_str[e]) != nullptr || isalnum(flag_str[e]))) { + e++; + } + if (e != flag_str.size() && flag_str[e] == '=' && + e + 1 != flag_str.size() && strchr("'\"", flag_str[e + 1]) != nullptr) { + // A flag of the form --flag="something in double or single quotes" + int c; + e++; // point just past '=' + size_t eflag = e; + char quote = flag_str[e]; + e++; // point just past quote + // Put in value the string with quotes removed. + string value; + for (; e != flag_str.size() && (c = flag_str[e]) != quote; e++) { + if (quote == '"' && c == '\\' && e + 1 != flag_str.size()) { + // Handle backslash in double quoted strings. They are literal in + // single-quoted strings. + e++; + c = flag_str[e]; + } + value += c; + } + if (e != flag_str.size()) { // skip final " or ' + e++; + } + AppendToEnvArgv(flag_str.data() + b, eflag - b, value.data(), + value.size(), a); + } else { // A flag without a quoted value. + e = FindFirstOf(flag_str, kWS, e); + AppendToEnvArgv(flag_str.data() + b, e - b, "", 0, a); + } + b = FindFirstNotOf(flag_str, kWS, e); + } +} + +// Call ParseArgvFromString(..., a) on a string derived from the setting of an +// environment variable kEnvVar, or a file it points to. 
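As a concrete walk-through of the parser above, here is a representative environment setting and the synthetic argv it yields (hypothetical values; the dummy argv[0] is the one added by SetArgvFromEnv() below):

//   TF_XLA_FLAGS="--simple --with_value=a_value --quoted='a b c'"
//
// ParseArgvFromString() appends, after the dummy argv[0]:
//   argv[1] == "--simple"
//   argv[2] == "--with_value=a_value"
//   argv[3] == "--quoted=a b c"   // quotes stripped, embedded spaces kept
// tensorflow::Flags::Parse() later consumes these entries exactly like
// ordinary command-line arguments.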
+static void SetArgvFromEnv(EnvArgv* a) { + if (!a->initialized) { + static const char kDummyArgv[] = ""; + AppendToEnvArgv(kDummyArgv, strlen(kDummyArgv), nullptr, 0, + a); // dummy argv[0] + const char* env = getenv(kEnvVar); + if (env == nullptr || env[0] == '\0') { + // nothing + } else if (env[strspn(env, kWS)] == '-') { // flags in env var value + ParseArgvFromString(env, a); + } else { // assume it's a file name + FILE* fp = fopen(env, "r"); + if (fp != nullptr) { + string str; + char buf[512]; + int n; + while ((n = fread(buf, 1, sizeof(buf), fp)) > 0) { + str.append(buf, n); + } + fclose(fp); + ParseArgvFromString(str, a); + } + } + AppendToEnvArgv(nullptr, 0, nullptr, 0, a); // add trailing nullptr to *a. + a->initialized = true; + } +} + +// The simulated argv[] parsed from the environment. +static EnvArgv* env_argv; + +// Used to protect accesses to env_argv. +static tensorflow::mutex env_argv_mu(tensorflow::LINKER_INITIALIZED); + +// Call Flags::Parse(argc, argv, flag_list) against any as yet unrecognized +// flags passed in from the environment. +bool ParseFlagsFromEnv(const std::vector& flag_list) { + env_argv_mu.lock(); + if (env_argv == nullptr) { + env_argv = new EnvArgv; + } + SetArgvFromEnv(env_argv); // a no-op if already initialized + bool result = + tensorflow::Flags::Parse(&env_argv->argc, &env_argv->argv[0], flag_list); + env_argv_mu.unlock(); + return result; +} + +// Testing only. +// Reset the env_argv struct so that subsequent calls to ParseFlagsFromEnv() +// will parse the environment variable (or the file it points to) anew, and set +// *pargc, and *pargv to point to the internal locations of the argc and argv +// constructed from the environment. +void ResetFlagsFromEnvForTesting(int** pargc, std::vector** pargv) { + env_argv_mu.lock(); + if (env_argv == nullptr) { + env_argv = new EnvArgv; + } + if (!env_argv->argv_save.empty()) { + for (int i = 0; env_argv->argv_save[i] != nullptr; i++) { + free(env_argv->argv_save[i]); + } + } + env_argv->initialized = false; + env_argv->argc = 0; + env_argv->argv.clear(); + env_argv->argv_save.clear(); + env_argv_mu.unlock(); + *pargc = &env_argv->argc; + *pargv = &env_argv->argv; +} + +} // namespace legacy_flags +} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h b/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h new file mode 100644 index 0000000000..b54482ad2b --- /dev/null +++ b/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h @@ -0,0 +1,66 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_PARSE_FLAGS_FROM_ENV_H_
+#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_PARSE_FLAGS_FROM_ENV_H_
+
+// This module exports ParseFlagsFromEnv(), which allows other modules to parse
+// flags from the environment variable TF_XLA_FLAGS, or (if the first
+// non-whitespace character in the variable value is not '-') a file named by
+// that environment variable. The accepted syntax is that flag arguments are
+// of the form --flag=value or (for boolean flags) --flag, and are whitespace
+// separated. The value may be one of:
+//  - an unquoted string of non-whitespace characters, in which case the
+//    effective value is the string itself
+//  - a single-quoted string, in which case the effective value is the string
+//    with the single quotes removed
+//  - a double-quoted string, in which case the effective value is the string
+//    with the double quotes removed and each backslash-escaped character
+//    replaced by the character itself.
+//
+// Flag values inconsistent with the type of the flag will be rejected by the
+// flag parser.
+//
+// Examples:
+//  TF_XLA_FLAGS="--foo=bar --wombat='value with a space'"
+//
+//  TF_XLA_FLAGS=/tmp/flagfile
+// where /tmp/flagfile might contain
+//  --some_flag="This is a string containing a \" and a '."
+//  --another_flag=wombats
+
+#include <vector>
+
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace xla {
+namespace legacy_flags {
+
+// Call tensorflow::Flags::Parse(argc, argv, flag_list) against any as yet
+// unrecognized flags passed in from the environment, and return its
+// return value.
+bool ParseFlagsFromEnv(const std::vector<tensorflow::Flag>& flag_list);
+
+// Used only for testing. Not to be used by clients.
+void ResetFlagsFromEnvForTesting(int** pargc, std::vector<char*>** pargv);
+
+} // namespace legacy_flags
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_PARSE_FLAGS_FROM_ENV_H_
diff --git a/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc b/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc
new file mode 100644
index 0000000000..7a966ce241
--- /dev/null
+++ b/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc
@@ -0,0 +1,190 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Test for parse_flags_from_env.cc
+
+#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <unistd.h>
+
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace xla {
+namespace legacy_flags {
+
+// Test that XLA flags can be set from the environment.
+// Failure messages are accompanied by the text in msg[].
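The tests in this file exercise the ordering contract that the rest of the patch relies on: the environment is parsed first, then the real command line, so an explicit command-line flag overrides TF_XLA_FLAGS. A condensed sketch of that order, excerpted from main() at the end of this file (argc, argv, and flag_list come from that surrounding main):

// Inside the test binary's main(), after flag_list has been built:
xla::legacy_flags::ParseFlagsFromEnv(flag_list);   // 1) TF_XLA_FLAGS first.
tensorflow::Flags::Parse(&argc, argv, flag_list);  // 2) argv wins on conflict,
                                                   //    as EnvAndFlag asserts.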
+static void TestParseFlagsFromEnv(const char* msg) {
+  // Initialize module under test.
+  int* pargc;
+  std::vector<char*>* pargv;
+  ResetFlagsFromEnvForTesting(&pargc, &pargv);
+
+  // Ensure that the environment variable can be parsed when
+  // no flags are expected.
+  std::vector<tensorflow::Flag> empty_flag_list;
+  bool parsed_ok = ParseFlagsFromEnv(empty_flag_list);
+  CHECK(parsed_ok) << msg;
+  const std::vector<char*>& argv_first = *pargv;
+  CHECK_NE(argv_first[0], nullptr) << msg;
+  int i = 0;
+  while (argv_first[i] != nullptr) {
+    i++;
+  }
+  CHECK_EQ(i, *pargc) << msg;
+
+  // Check that actual flags can be parsed.
+  bool simple = false;
+  string with_value;
+  string embedded_quotes;
+  string single_quoted;
+  string double_quoted;
+  std::vector<tensorflow::Flag> flag_list = {
+      tensorflow::Flag("simple", &simple, ""),
+      tensorflow::Flag("with_value", &with_value, ""),
+      tensorflow::Flag("embedded_quotes", &embedded_quotes, ""),
+      tensorflow::Flag("single_quoted", &single_quoted, ""),
+      tensorflow::Flag("double_quoted", &double_quoted, ""),
+  };
+  parsed_ok = ParseFlagsFromEnv(flag_list);
+  CHECK_EQ(*pargc, 1) << msg;
+  const std::vector<char*>& argv_second = *pargv;
+  CHECK_NE(argv_second[0], nullptr) << msg;
+  CHECK_EQ(argv_second[1], nullptr) << msg;
+  CHECK(parsed_ok) << msg;
+  CHECK(simple) << msg;
+  CHECK_EQ(with_value, "a_value") << msg;
+  CHECK_EQ(embedded_quotes, "single'double\"") << msg;
+  CHECK_EQ(single_quoted, "single quoted \\\\ \n \"") << msg;
+  CHECK_EQ(double_quoted, "double quoted \\ \n '\"") << msg;
+}
+
+// The flag settings to test.
+static const char kTestFlagString[] =
+    "--simple "
+    "--with_value=a_value "
+    "--embedded_quotes=single'double\" "
+    "--single_quoted='single quoted \\\\ \n \"' "
+    "--double_quoted=\"double quoted \\\\ \n '\\\"\" ";
+
+// Test that the environment variable is parsed correctly.
+TEST(ParseFlagsFromEnv, Basic) {
+  // Prepare environment.
+  setenv("TF_XLA_FLAGS", kTestFlagString, true /*overwrite*/);
+  TestParseFlagsFromEnv("(flags in environment variable)");
+}
+
+// Test that a file named by the environment variable is parsed correctly.
+TEST(ParseFlagsFromEnv, File) {
+  // Environment variables where the tmp dir may be specified.
+  static const char* kTempVars[] = {"TEST_TMPDIR", "TMP"};
+  static const char kTempDir[] = "/tmp";  // default temp dir if all else fails.
+  const char* tmp_dir = nullptr;
+  for (int i = 0; i != TF_ARRAYSIZE(kTempVars) && tmp_dir == nullptr; i++) {
+    tmp_dir = getenv(kTempVars[i]);
+  }
+  if (tmp_dir == nullptr) {
+    tmp_dir = kTempDir;
+  }
+  string tmp_file = tensorflow::strings::Printf("%s/parse_flags_from_env.%d",
+                                                tmp_dir, getpid());
+  FILE* fp = fopen(tmp_file.c_str(), "w");
+  CHECK_NE(fp, nullptr) << "can't write to " << tmp_file;
+  for (int i = 0; kTestFlagString[i] != '\0'; i++) {
+    putc(kTestFlagString[i], fp);
+  }
+  fflush(fp);
+  CHECK_EQ(ferror(fp), 0) << "writes failed to " << tmp_file;
+  fclose(fp);
+  // Prepare environment.
+  setenv("TF_XLA_FLAGS", tmp_file.c_str(), true /*overwrite*/);
+  TestParseFlagsFromEnv("(flags in file)");
+  unlink(tmp_file.c_str());
+}
+
+// Name of the test binary.
+static const char* binary_name;
+
+// Test that when we use both the environment variable and actual
+// command-line flags (when the latter is possible), the latter win.
+TEST(ParseFlagsFromEnv, EnvAndFlag) {
+  // TODO(m3b): convert to Subprocess when CL 137771604 is finished.
+ static struct { + const char* env; + const char* arg; + const char* expected_value; + } test[] = { + {nullptr, nullptr, "1\n"}, + {nullptr, "--int_flag=2", "2\n"}, + {"--int_flag=3", nullptr, "3\n"}, + {"--int_flag=3", "--int_flag=2", "2\n"}, // flag beats environment + }; + for (int i = 0; i != TF_ARRAYSIZE(test); i++) { + if (test[i].env != nullptr) { + setenv("TF_XLA_FLAGS", test[i].env, true /*overwrite*/); + } + tensorflow::SubProcess child; + std::vector argv; + argv.push_back(binary_name); + argv.push_back("--recursing"); + if (test[i].arg != nullptr) { + argv.push_back(test[i].arg); + } + child.SetProgram(binary_name, argv); + child.SetChannelAction(tensorflow::CHAN_STDOUT, tensorflow::ACTION_PIPE); + CHECK(child.Start()) << "test " << i; + string stdout_str; + int child_status = child.Communicate(nullptr, &stdout_str, nullptr); + CHECK_EQ(child_status, 0) << "test " << i; + CHECK_EQ(stdout_str, test[i].expected_value) << "test " << i; + } +} + +} // namespace legacy_flags +} // namespace xla + +int main(int argc, char* argv[]) { + // Save name of binary so that it may invoke itself. + xla::legacy_flags::binary_name = argv[0]; + bool recursing = false; + xla::int32 int_flag = 1; + const std::vector flag_list = { + tensorflow::Flag("recursing", &recursing, + "Whether the binary is being invoked recusively."), + tensorflow::Flag("int_flag", &int_flag, "An integer flag to test with"), + }; + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + bool parse_ok = xla::legacy_flags::ParseFlagsFromEnv(flag_list); + if (!parse_ok) { + LOG(QFATAL) << "can't parse from environment\n" << usage; + } + parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_ok) { + LOG(QFATAL) << usage; + } + if (recursing) { + printf("%d\n", int_flag); + exit(0); + } + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/legacy_flags/service_flags.cc b/tensorflow/compiler/xla/legacy_flags/service_flags.cc new file mode 100644 index 0000000000..41cb8d8bdf --- /dev/null +++ b/tensorflow/compiler/xla/legacy_flags/service_flags.cc @@ -0,0 +1,100 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Legacy flags for XLA's service module. + +#include // NOLINT(build/c++11): only using std::call_once, not mutex. +#include + +#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" +#include "tensorflow/compiler/xla/legacy_flags/service_flags.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace xla { +namespace legacy_flags { + +// Pointers to the parsed value of the flags and flag descriptors, initialized +// via flags_init. +static ServiceFlags* flags; +static std::vector* flag_list; +static std::once_flag flags_init; + +// Allocate *flags. Called via call_once(&flags_init,...). 
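Because AllocateFlags() below runs ParseFlagsFromEnv(), each service flag can also be enabled per run through the environment rather than the command line. An illustrative invocation (the binary name and dump directory are placeholders; the flag names are the ones defined below):

  TF_XLA_FLAGS='--xla_generate_hlo_graph=.* --xla_hlo_graph_layout --xla_dump_hlo_text_to=/tmp/hlo' ./some_xla_binary

GetServiceFlags() then exposes the parsed values to the service code.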
+static void AllocateFlags() { + flags = new ServiceFlags; + flags->xla_hlo_profile = false; + flags->xla_log_hlo_text = ""; + flags->xla_generate_hlo_graph = ""; + flags->xla_hlo_graph_addresses = false; + flags->xla_hlo_graph_layout = false; + flags->xla_hlo_graph_for_compute_constant = false; + flags->xla_dump_computations_to = ""; + flags->xla_dump_hlo_text_to = ""; + flags->xla_dump_executions_to = ""; + flag_list = new std::vector({ + tensorflow::Flag( + "xla_hlo_profile", &flags->xla_hlo_profile, + "Instrument the computation to collect per-HLO cycle counts"), + tensorflow::Flag( + "xla_log_hlo_text", &flags->xla_log_hlo_text, + "If non-empty, print the text format of " + "HLO modules whose name partially matches this regex. E.g. " + "xla_log_hlo_text=.* will dump the text for every module."), + tensorflow::Flag( + "xla_generate_hlo_graph", &flags->xla_generate_hlo_graph, + "If non-empty, dump graph of HLO modules whose name partially " + "matches this regex. E.g. --xla_generate_hlo_graph=.* will dump " + "the graph of every module."), + tensorflow::Flag("xla_hlo_graph_addresses", + &flags->xla_hlo_graph_addresses, + "Show addresses of HLO ops in graph"), + tensorflow::Flag("xla_hlo_graph_layout", &flags->xla_hlo_graph_layout, + "Show layout of HLO ops in graph"), + tensorflow::Flag( + "xla_hlo_graph_for_compute_constant", + &flags->xla_hlo_graph_for_compute_constant, + "If true, include hlo dumps of graphs from ComputeConstant." + "Such graphs still need to be matched via xla_generate_hlo_graph."), + tensorflow::Flag("xla_dump_computations_to", + &flags->xla_dump_computations_to, + "Dumps computations that XLA executes into the provided " + "directory path"), + tensorflow::Flag("xla_dump_hlo_text_to", &flags->xla_dump_hlo_text_to, + "Dumps HLO modules that XLA executes into the provided " + "directory path"), + tensorflow::Flag("xla_dump_executions_to", &flags->xla_dump_executions_to, + "Dumps parameters and results of computations that XLA " + "executes into the provided directory path"), + }); + ParseFlagsFromEnv(*flag_list); +} + +// Append to *append_to flag definitions associated with XLA's service module. +void AppendServiceFlags(std::vector* append_to) { + std::call_once(flags_init, &AllocateFlags); + append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); +} + +// Return a pointer to the ServiceFlags struct; +// repeated calls return the same pointer. +// This should be called only after Flags::Parse() has returned. +ServiceFlags* GetServiceFlags() { + std::call_once(flags_init, &AllocateFlags); + return flags; +} + +} // namespace legacy_flags +} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/service_flags.h b/tensorflow/compiler/xla/legacy_flags/service_flags.h new file mode 100644 index 0000000000..d982506944 --- /dev/null +++ b/tensorflow/compiler/xla/legacy_flags/service_flags.h @@ -0,0 +1,69 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_SERVICE_FLAGS_H_ +#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_SERVICE_FLAGS_H_ + +// Legacy flags for XLA's service module. + +#include + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace xla { +namespace legacy_flags { + +// Append to *flag_list flag definitions associated with XLA's service module. +void AppendServiceFlags(std::vector* flag_list); + +// The values of flags associated with XLA's service module. +typedef struct { + bool xla_hlo_profile; // Instrument the computation to collect per-HLO cycle + // counts + string xla_log_hlo_text; // If non-empty, print the text format of the HLO + // modules whose name partially + // matches this regex. E.g. xla_log_hlo_text=.* + // will dump the text for every module. + string xla_generate_hlo_graph; // If non-empty, dump graph of HLO modules + // whose name partially matches this regex. + // E.g. --xla_generate_hlo_graph=.* will dump + // the graph of every module. + bool xla_hlo_graph_addresses; // Show addresses of HLO ops in graph + bool xla_hlo_graph_layout; // Show layout of HLO ops in graph + bool xla_hlo_graph_for_compute_constant; // If true, include hlo dumps of + // graphs from ComputeConstant. + // Such graphs still need to be + // matched via + // xla_generate_hlo_graph. + string xla_dump_hlo_text_to; // Dumps HLO text for each HLO module that is + // executed into the provided directory path + string xla_dump_computations_to; // Dumps computations that XLA executes + // into the provided directory path + // Dumps parameters and results of computations that XLA executes into + // the provided directory path + string xla_dump_executions_to; +} ServiceFlags; + +// Return a pointer to the ServiceFlags struct; +// repeated calls return the same pointer. +// This should be called only after Flags::Parse() has returned. +ServiceFlags* GetServiceFlags(); + +} // namespace legacy_flags +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_SERVICE_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/stream_assignment_flags.cc b/tensorflow/compiler/xla/legacy_flags/stream_assignment_flags.cc new file mode 100644 index 0000000000..6506175777 --- /dev/null +++ b/tensorflow/compiler/xla/legacy_flags/stream_assignment_flags.cc @@ -0,0 +1,63 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Legacy flags for XLA's stream_assignment module. + +#include // NOLINT(build/c++11): only using std::call_once, not mutex. 
+#include + +#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" +#include "tensorflow/compiler/xla/legacy_flags/stream_assignment_flags.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace xla { +namespace legacy_flags { + +// Pointers to the parsed value of the flags and flag descriptors, initialized +// via flags_init. +static StreamAssignmentFlags* flags; +static std::vector* flag_list; +static std::once_flag flags_init; + +// Allocate *flags. Called via call_once(&flags_init,...). +static void AllocateFlags() { + flags = new StreamAssignmentFlags; + flags->xla_gpu_disable_multi_streaming = false; + flag_list = new std::vector({ + tensorflow::Flag("xla_gpu_disable_multi_streaming", + &flags->xla_gpu_disable_multi_streaming, + "Disable multi-streaming in XLA's GPU backend"), + }); + ParseFlagsFromEnv(*flag_list); +} + +// Append to *append_to flag definitions associated with XLA's stream_assignment +// module. +void AppendStreamAssignmentFlags(std::vector* append_to) { + std::call_once(flags_init, &AllocateFlags); + append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); +} + +// Return a pointer to the StreamAssignmentFlags struct; +// repeated calls return the same pointer. +// This should be called only after Flags::Parse() has returned. +StreamAssignmentFlags* GetStreamAssignmentFlags() { + std::call_once(flags_init, &AllocateFlags); + return flags; +} + +} // namespace legacy_flags +} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/stream_assignment_flags.h b/tensorflow/compiler/xla/legacy_flags/stream_assignment_flags.h new file mode 100644 index 0000000000..a98f9b3458 --- /dev/null +++ b/tensorflow/compiler/xla/legacy_flags/stream_assignment_flags.h @@ -0,0 +1,47 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_STREAM_ASSIGNMENT_FLAGS_H_ +#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_STREAM_ASSIGNMENT_FLAGS_H_ + +// Legacy flags for XLA's stream_assignment module. + +#include + +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace xla { +namespace legacy_flags { + +// Append to *flag_list flag definitions associated with XLA's stream_assignment +// module. +void AppendStreamAssignmentFlags(std::vector* flag_list); + +// The values of flags associated with XLA's stream_assignment module. +typedef struct { + bool xla_gpu_disable_multi_streaming; // Disable multi-streaming in XLA's GPU + // backend +} StreamAssignmentFlags; + +// Return a pointer to the StreamAssignmentFlags struct; +// repeated calls return the same pointer. +// This should be called only after Flags::Parse() has returned. 
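A hypothetical call site in the GPU backend would read the parsed value once through this accessor (sketch only; the actual stream_assignment consumers are not shown in this header):

inline bool MultiStreamingEnabled() {
  return !xla::legacy_flags::GetStreamAssignmentFlags()
              ->xla_gpu_disable_multi_streaming;
}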
+StreamAssignmentFlags* GetStreamAssignmentFlags(); + +} // namespace legacy_flags +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_STREAM_ASSIGNMENT_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/util_flags.cc b/tensorflow/compiler/xla/legacy_flags/util_flags.cc new file mode 100644 index 0000000000..e6df19ddd2 --- /dev/null +++ b/tensorflow/compiler/xla/legacy_flags/util_flags.cc @@ -0,0 +1,62 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Legacy flags for XLA's util module. + +#include // NOLINT(build/c++11): only using std::call_once, not mutex. +#include + +#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" +#include "tensorflow/compiler/xla/legacy_flags/util_flags.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace xla { +namespace legacy_flags { + +// Pointers to the parsed value of the flags and flag descriptors, initialized +// via flags_init. +static UtilFlags* flags; +static std::vector* flag_list; +static std::once_flag flags_init; + +// Allocate *flags. Called via call_once(&flags_init,...). +static void AllocateFlags() { + flags = new UtilFlags; + flags->xla_status_add_backtrace = false; + flag_list = new std::vector({ + tensorflow::Flag("xla_status_add_backtrace", + &flags->xla_status_add_backtrace, + "add backtraces to XLA-produced status values"), + }); + ParseFlagsFromEnv(*flag_list); +} + +// Append to *append_to flag definitions associated with XLA's util module. +void AppendUtilFlags(std::vector* append_to) { + std::call_once(flags_init, &AllocateFlags); + append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); +} + +// Return a pointer to the UtilFlags struct; +// repeated calls return the same pointer. +// This should be called only after Flags::Parse() has returned. +UtilFlags* GetUtilFlags() { + std::call_once(flags_init, &AllocateFlags); + return flags; +} + +} // namespace legacy_flags +} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/util_flags.h b/tensorflow/compiler/xla/legacy_flags/util_flags.h new file mode 100644 index 0000000000..03bffcd726 --- /dev/null +++ b/tensorflow/compiler/xla/legacy_flags/util_flags.h @@ -0,0 +1,45 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_UTIL_FLAGS_H_ +#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_UTIL_FLAGS_H_ + +// Legacy flags for the XLA's util module. + +#include + +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace xla { +namespace legacy_flags { + +// Append to *flag_list flag definitions associated with XLA's util module. +void AppendUtilFlags(std::vector* flag_list); + +// The values of flags associated with XLA's util module. +typedef struct { + bool xla_status_add_backtrace; // add backtraces to XLA-produced statuses +} UtilFlags; + +// Return a pointer to the UtilFlags struct; +// repeated calls return the same pointer. +// This should be called only after Flags::Parse() has returned. +UtilFlags* GetUtilFlags(); + +} // namespace legacy_flags +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_UTIL_FLAGS_H_ diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc new file mode 100644 index 0000000000..f2b2bf8cec --- /dev/null +++ b/tensorflow/compiler/xla/literal_util.cc @@ -0,0 +1,989 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/literal_util.h" + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/index_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +/* static */ Literal LiteralUtil::Zero(PrimitiveType primitive_type) { + switch (primitive_type) { + case U8: + return *LiteralUtil::CreateR0(0); + case U32: + return *LiteralUtil::CreateR0(0); + case U64: + return *LiteralUtil::CreateR0(0); + case S8: + return *LiteralUtil::CreateR0(0); + case S32: + return *LiteralUtil::CreateR0(0); + case S64: + return *LiteralUtil::CreateR0(0); + case F32: + return *LiteralUtil::CreateR0(0); + case F64: + return *LiteralUtil::CreateR0(0); + case PRED: + return *LiteralUtil::CreateR0(false); + case S16: + case U16: + LOG(FATAL) << "u16/s16 literals not yet implemented"; + case F16: + LOG(FATAL) << "f16 literals not yet implemented"; + case TUPLE: + LOG(FATAL) << "tuple element type cannot take on value of 0"; + case OPAQUE: + LOG(FATAL) << "opaque element type cannot take on value of 0"; + default: + LOG(FATAL) << "Unhandled primitive type " << primitive_type; + } +} + +/* static */ Literal LiteralUtil::One(PrimitiveType primitive_type) { + switch (primitive_type) { + case U8: + return *LiteralUtil::CreateR0(1); + case U32: + return *LiteralUtil::CreateR0(1); + case U64: + return *LiteralUtil::CreateR0(1); + case S8: + return *LiteralUtil::CreateR0(1); + case S32: + return *LiteralUtil::CreateR0(1); + case S64: + return *LiteralUtil::CreateR0(1); + case F32: + return *LiteralUtil::CreateR0(1); + case F64: + return *LiteralUtil::CreateR0(1); + case PRED: + return *LiteralUtil::CreateR0(true); + case S16: + case U16: + LOG(FATAL) << "u16/s16 literals not yet implemented"; + case F16: + LOG(FATAL) << "f16 literals not yet implemented"; + case TUPLE: + LOG(FATAL) << "tuple element type cannot take on value of 1"; + case OPAQUE: + LOG(FATAL) << "opaque element type cannot take on value of 1"; + default: + LOG(FATAL) << "Unhandled primitive type " << primitive_type; + } +} + +/* static */ Literal LiteralUtil::MinValue(PrimitiveType primitive_type) { + switch (primitive_type) { + case U8: + return *LiteralUtil::CreateR0(std::numeric_limits::min()); + case U32: + return *LiteralUtil::CreateR0(std::numeric_limits::min()); + case U64: + return *LiteralUtil::CreateR0(std::numeric_limits::min()); + case S8: + return *LiteralUtil::CreateR0(std::numeric_limits::min()); + case S32: + return *LiteralUtil::CreateR0(std::numeric_limits::min()); + case S64: + return *LiteralUtil::CreateR0(std::numeric_limits::min()); + case F32: + return *LiteralUtil::CreateR0( + -std::numeric_limits::infinity()); + case F64: + return *LiteralUtil::CreateR0( + -std::numeric_limits::infinity()); + case PRED: + return *LiteralUtil::CreateR0(false); + case S16: + case U16: + LOG(FATAL) << "u16/s16 literals not yet implemented"; + case F16: + LOG(FATAL) << "f16 literals not yet implemented"; + case TUPLE: + LOG(FATAL) << "tuple element type has no minimum value"; + case OPAQUE: + LOG(FATAL) << "opaque element type has no 
minimum value"; + default: + LOG(FATAL) << "Unhandled primitive type " << primitive_type; + } +} + +/* static */ Literal LiteralUtil::MaxValue(PrimitiveType primitive_type) { + switch (primitive_type) { + case U8: + return *LiteralUtil::CreateR0(std::numeric_limits::max()); + case U32: + return *LiteralUtil::CreateR0(std::numeric_limits::max()); + case U64: + return *LiteralUtil::CreateR0(std::numeric_limits::max()); + case S8: + return *LiteralUtil::CreateR0(std::numeric_limits::max()); + case S32: + return *LiteralUtil::CreateR0(std::numeric_limits::max()); + case S64: + return *LiteralUtil::CreateR0(std::numeric_limits::max()); + case F32: + return *LiteralUtil::CreateR0( + std::numeric_limits::infinity()); + case F64: + return *LiteralUtil::CreateR0( + std::numeric_limits::infinity()); + case PRED: + return *LiteralUtil::CreateR0(true); + case S16: + case U16: + LOG(FATAL) << "u16/s16 literals not yet implemented"; + case F16: + LOG(FATAL) << "f16 literals not yet implemented"; + case TUPLE: + LOG(FATAL) << "tuple element type has no maximum value"; + case OPAQUE: + LOG(FATAL) << "opaque element type has no maximum value"; + default: + LOG(FATAL) << "Unhandled primitive type " << primitive_type; + } +} + +/* static */ std::unique_ptr LiteralUtil::CreateR1( + const tensorflow::core::Bitmap& values) { + auto literal = MakeUnique(); + PopulateR1(values, literal.get()); + return literal; +} + +/* static */ std::unique_ptr LiteralUtil::CreateR1U8( + tensorflow::StringPiece value) { + auto literal = MakeUnique(); + *literal->mutable_shape() = + ShapeUtil::MakeShape(U8, {static_cast(value.size())}); + literal->set_u8s(value.ToString()); + return literal; +} + +/* static */ std::unique_ptr LiteralUtil::CreateR2F32Linspace( + float from, float to, int64 rows, int64 cols) { + auto value = MakeLinspaceArray2D(from, to, rows, cols); + return CreateR2FromArray2D(*value); +} + +/* static */ std::unique_ptr LiteralUtil::Relayout( + const Literal& original, const Layout& layout) { + // Note: if this were a performance bottleneck, we avoid cloning and just make + // an uninitialized array instead, since all values are clobbered below. 
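  // The loop below only changes the physical order of the data; the logical
  // (index -> value) mapping is preserved. A usage sketch in comments
  // (assumed example: `lit` is any 2x2 F32 literal; only F32/S32/U32 element
  // types are handled so far):
  //   auto relaid =
  //       LiteralUtil::Relayout(lit, LayoutUtil::MakeLayout({0, 1}));
  //   // For every (i, j):
  //   //   Get<float>(*relaid, {i, j}) == Get<float>(lit, {i, j})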
+ std::unique_ptr result = CloneToUnique(original); + *result->mutable_shape()->mutable_layout() = layout; + const PrimitiveType primitive_type = original.shape().element_type(); + switch (primitive_type) { + case F32: + LiteralUtil::EachCell( + original, + [&](tensorflow::gtl::ArraySlice indices, float value) { + LiteralUtil::Set(result.get(), indices, value); + }); + return result; + case S32: + LiteralUtil::EachCell( + original, + [&](tensorflow::gtl::ArraySlice indices, int32 value) { + LiteralUtil::Set(result.get(), indices, value); + }); + return result; + case U32: + LiteralUtil::EachCell( + original, + [&](tensorflow::gtl::ArraySlice indices, uint32 value) { + LiteralUtil::Set(result.get(), indices, value); + }); + return result; + default: + LOG(FATAL) << "not yet implemented: " + << PrimitiveType_Name(primitive_type); + } +} + +/* static */ StatusOr> LiteralUtil::Reshape( + const xla::Literal& input, tensorflow::gtl::ArraySlice dimensions) { + if (ShapeUtil::IsTuple(input.shape())) { + return InvalidArgument("Reshape does not support tuples."); + } + + if (!LayoutUtil::IsMonotonicWithDim0Major(input.shape().layout())) { + return Unimplemented( + "Input shape must have a monotonic layout where dimension 0 is major, " + "was: %s", + LayoutUtil::HumanString(input.shape().layout()).c_str()); + } + std::vector layout(dimensions.size()); + std::iota(layout.rbegin(), layout.rend(), 0); + + // Because the layout is monotonic, we can simply reuse the same sequence of + // values without changing their order. + std::unique_ptr output = CloneToUnique(input); + output->clear_shape(); + output->mutable_shape()->set_element_type(input.shape().element_type()); + for (int64 dimension : dimensions) { + output->mutable_shape()->add_dimensions(dimension); + } + *output->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout(layout); + + int64 elements_before = ShapeUtil::ElementsIn(input.shape()); + int64 elements_after = ShapeUtil::ElementsIn(output->shape()); + if (elements_before != elements_after) { + return InvalidArgument( + "Shapes before and after LiteralUtil::Reshape have different numbers " + "of elements: %s vs %s.", + ShapeUtil::HumanString(input.shape()).c_str(), + ShapeUtil::HumanString(output->shape()).c_str()); + } + return std::move(output); +} + +/* static */ std::unique_ptr LiteralUtil::Transpose( + const Literal& original, tensorflow::gtl::ArraySlice permutation) { + CHECK(!ShapeUtil::IsTuple(original.shape())) + << "tuple is not supported for transpose"; + std::vector dimension_numbers(ShapeUtil::Rank(original.shape())); + std::iota(dimension_numbers.begin(), dimension_numbers.end(), 0); + CHECK(std::is_permutation(permutation.begin(), permutation.end(), + dimension_numbers.begin())) + << "given permutation is not a permutation of dimension numbers"; + std::vector new_dimension_sizes; + for (const int64 dim : permutation) { + new_dimension_sizes.push_back(original.shape().dimensions(dim)); + } + const auto result_shape = ShapeUtil::MakeShape( + original.shape().element_type(), new_dimension_sizes); + std::unique_ptr result = CloneToUnique(original); + *result->mutable_shape() = result_shape; + const PrimitiveType primitive_type = original.shape().element_type(); + std::vector new_indices(ShapeUtil::Rank(original.shape())); + switch (primitive_type) { + case F32: + LiteralUtil::EachCell( + original, + [&](tensorflow::gtl::ArraySlice indices, float value) { + for (int64 i = 0; i < permutation.size(); ++i) { + new_indices[i] = indices[permutation[i]]; + } + 
LiteralUtil::Set(result.get(), new_indices, value); + }); + return result; + default: + LOG(FATAL) << "not yet implemented: " + << PrimitiveType_Name(primitive_type); + } +} + +/* static */ std::unique_ptr LiteralUtil::Slice( + const Literal& literal, tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices) { + CHECK(!ShapeUtil::IsTuple(literal.shape())) + << "tuple is not supported for reshape"; + + std::vector result_dimensions; + for (int64 dnum = 0; dnum < ShapeUtil::Rank(literal.shape()); ++dnum) { + CHECK_GE(start_indices[dnum], 0); + CHECK_LE(limit_indices[dnum], literal.shape().dimensions(dnum)); + int64 dimension = limit_indices[dnum] - start_indices[dnum]; + CHECK_GT(dimension, 0); + result_dimensions.push_back(dimension); + } + const auto result_shape = ShapeUtil::MakeShapeWithLayout( + literal.shape().element_type(), result_dimensions, + AsInt64Slice(literal.shape().layout().minor_to_major())); + + auto result_literal = MakeUnique(); + *result_literal->mutable_shape() = result_shape; + Reserve(ShapeUtil::ElementsIn(result_shape), result_literal.get()); + + std::vector new_indices(ShapeUtil::Rank(result_shape)); + switch (result_shape.element_type()) { + case F32: + LiteralUtil::EachCell( + *result_literal, + [&](tensorflow::gtl::ArraySlice indices, float /*value*/) { + for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { + new_indices[i] = indices[i] + start_indices[i]; + } + float value = LiteralUtil::Get(literal, new_indices); + LiteralUtil::Set(result_literal.get(), indices, value); + }); + return result_literal; + case S32: + LiteralUtil::EachCell( + *result_literal, + [&](tensorflow::gtl::ArraySlice indices, int32 /*value*/) { + for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { + new_indices[i] = indices[i] + start_indices[i]; + } + int32 value = LiteralUtil::Get(literal, new_indices); + LiteralUtil::Set(result_literal.get(), indices, value); + }); + return result_literal; + case U32: + LiteralUtil::EachCell( + *result_literal, + [&](tensorflow::gtl::ArraySlice indices, uint32 /*value*/) { + for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { + new_indices[i] = indices[i] + start_indices[i]; + } + uint32 value = LiteralUtil::Get(literal, new_indices); + LiteralUtil::Set(result_literal.get(), indices, value); + }); + return result_literal; + default: + LOG(FATAL) << "not yet implemented: " + << PrimitiveType_Name(result_shape.element_type()); + } +} + +/* static */ std::unique_ptr LiteralUtil::CloneToUnique( + const Literal& literal) { + auto unique = MakeUnique(); + *unique = literal; + return unique; +} + +/* static */ string LiteralUtil::GetAsString( + const Literal& literal, tensorflow::gtl::ArraySlice multi_index) { + switch (literal.shape().element_type()) { + case PRED: + return Get(literal, multi_index) ? 
"true" : "false"; + case U8: + return tensorflow::strings::StrCat(Get(literal, multi_index)); + case S32: + return tensorflow::strings::StrCat(Get(literal, multi_index)); + case S64: + return tensorflow::strings::StrCat(Get(literal, multi_index)); + case U32: + return tensorflow::strings::StrCat(Get(literal, multi_index)); + case U64: + return tensorflow::strings::StrCat(Get(literal, multi_index)); + case F32: + return tensorflow::strings::StrCat(Get(literal, multi_index)); + case F64: + return tensorflow::strings::StrCat(Get(literal, multi_index)); + default: + return tensorflow::strings::StrCat( + "[", PrimitiveType_Name(literal.shape().element_type()), "]"); + } +} + +/* static */ int64 LiteralUtil::LinearIndex( + const Literal& literal, tensorflow::gtl::ArraySlice multi_index) { + return IndexUtil::MultidimensionalIndexToLinearIndex(literal.shape(), + multi_index); +} + +/* static */ string LiteralUtil::ToString(const Literal& literal) { + const Shape& shape = literal.shape(); + std::vector pieces; + + auto element_to_string = + [&literal](tensorflow::gtl::ArraySlice indices) -> string { + PrimitiveType element_type = literal.shape().element_type(); + if (element_type == PRED) { + // We display predicates in a densely packed form. + return Get(literal, indices) ? "1" : "0"; + } + return ((!indices.empty() && indices.back() > 0) ? ", " : "") + + GetAsString(literal, indices); + }; + + // TODO(b/32894291): refactor this code to reduce code duplication. + if (ShapeUtil::IsTuple(shape)) { + pieces.push_back(ShapeUtil::HumanString(shape)); + pieces.push_back(" (\n"); + for (const auto& element_literal : literal.tuple_literals()) { + pieces.push_back(ToString(element_literal)); + pieces.push_back(",\n"); + } + pieces.push_back(")"); + } else if (ShapeUtil::Rank(shape) == 0) { + pieces.push_back(GetAsString(literal, {})); + } else if (ShapeUtil::Rank(shape) == 1) { + pieces.push_back("{"); + for (int64 i0 = 0; i0 < shape.dimensions(0); ++i0) { + pieces.push_back(element_to_string({i0})); + } + pieces.push_back("}"); + } else if (ShapeUtil::Rank(shape) == 2) { + pieces.push_back(ShapeUtil::HumanString(shape)); + pieces.push_back(" {\n"); + for (int64 i0 = 0; i0 < shape.dimensions(0); ++i0) { + pieces.push_back(" { "); + for (int64 i1 = 0; i1 < shape.dimensions(1); ++i1) { + pieces.push_back(element_to_string({i0, i1})); + } + pieces.push_back(" "); + pieces.push_back("},\n"); + } + pieces.push_back("}"); + } else if (ShapeUtil::Rank(shape) == 3) { + pieces.push_back(ShapeUtil::HumanString(shape)); + pieces.push_back(" {\n"); + for (int64 i0 = 0; i0 < shape.dimensions(0); ++i0) { + pieces.push_back(i0 > 0 ? ",\n{" : "{"); + for (int64 i1 = 0; i1 < shape.dimensions(1); ++i1) { + pieces.push_back(i1 > 0 ? 
",\n { " : " { "); + for (int64 i2 = 0; i2 < shape.dimensions(2); ++i2) { + pieces.push_back(element_to_string({i0, i1, i2})); + } + pieces.push_back(" }"); + } + pieces.push_back(" }"); + } + pieces.push_back("\n}"); + } else if (ShapeUtil::Rank(shape) == 4) { + pieces.push_back(ShapeUtil::HumanString(shape)); + pieces.push_back(" {\n"); + for (int64 i0 = 0; i0 < shape.dimensions(0); ++i0) { + pieces.push_back(tensorflow::strings::Printf(" { // i0=%lld\n", i0)); + for (int64 i1 = 0; i1 < shape.dimensions(1); ++i1) { + pieces.push_back( + tensorflow::strings::Printf(" { // i1=%lld\n", i1)); + for (int64 i2 = 0; i2 < shape.dimensions(2); ++i2) { + pieces.push_back(" {"); + for (int64 i3 = 0; i3 < shape.dimensions(3); ++i3) { + pieces.push_back(element_to_string({i0, i1, i2, i3})); + } + pieces.push_back("},\n"); + } + pieces.push_back(" },\n"); + } + pieces.push_back(" },\n"); + } + pieces.push_back("}"); + } else if (ShapeUtil::Rank(shape) == 5) { + pieces.push_back(ShapeUtil::HumanString(shape)); + pieces.push_back(" {\n"); + for (int64 i0 = 0; i0 < shape.dimensions(0); ++i0) { + pieces.push_back(tensorflow::strings::Printf(" { // i0=%lld\n", i0)); + for (int64 i1 = 0; i1 < shape.dimensions(1); ++i1) { + pieces.push_back( + tensorflow::strings::Printf(" { // i1=%lld\n", i1)); + for (int64 i2 = 0; i2 < shape.dimensions(2); ++i2) { + pieces.push_back( + tensorflow::strings::Printf(" { // i2=%lld\n", i2)); + for (int64 i3 = 0; i3 < shape.dimensions(3); ++i3) { + pieces.push_back(" {"); + for (int64 i4 = 0; i4 < shape.dimensions(4); ++i4) { + pieces.push_back(element_to_string({i0, i1, i2, i3, i4})); + } + pieces.push_back("},\n"); + } + pieces.push_back(" },\n"); + } + pieces.push_back(" },\n"); + } + pieces.push_back(" },\n"); + } + pieces.push_back("}"); + } else { + pieces.push_back(ShapeUtil::HumanString(shape)); + pieces.push_back(" {...}"); + } + + return tensorflow::str_util::Join(pieces, ""); +} + +/* static */ std::unique_ptr LiteralUtil::MakeTuple( + tensorflow::gtl::ArraySlice elements) { + auto literal = MakeUnique(); + std::vector shape; + for (const Literal* tuple_element : elements) { + *literal->add_tuple_literals() = *tuple_element; + shape.push_back(tuple_element->shape()); + } + *literal->mutable_shape() = ShapeUtil::MakeTupleShape(shape); + return literal; +} + +/* static */ const void* LiteralUtil::InternalData(const Literal& literal) { + switch (literal.shape().element_type()) { + case PRED: + return reinterpret_cast(literal.preds().data()); + case U8: + return reinterpret_cast(literal.u8s().data()); + case S32: + return reinterpret_cast(literal.s32s().data()); + case S64: + return reinterpret_cast(literal.s64s().data()); + case U32: + return reinterpret_cast(literal.u32s().data()); + case U64: + return reinterpret_cast(literal.u64s().data()); + case F32: + return reinterpret_cast(literal.f32s().data()); + case F64: + return reinterpret_cast(literal.f64s().data()); + default: + LOG(FATAL) << "primitive type not supported in literals: " + << PrimitiveType_Name(literal.shape().element_type()); + } +} + +/* static */ void* LiteralUtil::MutableInternalData(Literal* literal) { + return const_cast(LiteralUtil::InternalData(*literal)); +} + +/* static */ void LiteralUtil::Reserve(int64 num_elements, Literal* literal) { + CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); + switch (literal->shape().element_type()) { + case PRED: + GetMutableRepeatedField(literal)->Resize(num_elements, false); + break; + case U8: + // u8s is an optional "bytes", rather than a 
repeated field. Therefore its + // access methods are somewhat different from the others. + literal->mutable_u8s()->resize(num_elements, 0); + break; + case S32: + GetMutableRepeatedField(literal)->Resize(num_elements, + /*value=*/0); + break; + case S64: + GetMutableRepeatedField(literal)->Resize( + num_elements, + /*value=*/0); + break; + case U32: + GetMutableRepeatedField(literal)->Resize(num_elements, + /*value=*/0); + break; + case U64: + GetMutableRepeatedField(literal)->Resize( + num_elements, + /*value=*/0); + break; + case F32: + GetMutableRepeatedField(literal)->Resize(num_elements, + /*value=*/0.0f); + break; + case F64: + GetMutableRepeatedField(literal)->Resize(num_elements, + /*value=*/0.0); + break; + default: + LOG(FATAL) << "primitive type not supported in literals: " + << PrimitiveType_Name(literal->shape().element_type()); + } +} + +/* static */ tensorflow::Status LiteralUtil::ValidateLiteral( + const Literal& literal) { + TF_CHECK_OK(ShapeUtil::ValidateShape(literal.shape())); + int64 expected = ShapeUtil::ElementsIn(literal.shape()); + int64 actual = -1; + switch (literal.shape().element_type()) { + case PRED: + actual = literal.preds().size(); + break; + case U8: + actual = literal.u8s().size(); + break; + case S32: + actual = literal.s32s_size(); + break; + case U32: + actual = literal.u32s_size(); + break; + case S64: + actual = literal.s64s_size(); + break; + case U64: + actual = literal.u64s_size(); + break; + case F32: + actual = literal.f32s_size(); + break; + case F64: + actual = literal.f64s_size(); + break; + default: + return tensorflow::errors::Unimplemented( + "unhandled element type for literal validation: " + + PrimitiveType_Name(literal.shape().element_type())); + } + + if (expected != actual) { + return tensorflow::errors::InvalidArgument(tensorflow::strings::Printf( + "literal has bad number of elements for its shape %s: want %lld " + "got %lld", + ShapeUtil::HumanString(literal.shape()).c_str(), expected, actual)); + } + + return tensorflow::Status::OK(); +} + +/* static */ void LiteralUtil::EachCellAsString( + const Literal& literal, + std::function indices, + const string& value)> + per_cell) { + if (ShapeUtil::Rank(literal.shape()) == 1) { + for (int64 i0 = 0; i0 < literal.shape().dimensions(0); ++i0) { + per_cell({i0}, GetAsString(literal, {i0})); + } + return; + } + + if (ShapeUtil::Rank(literal.shape()) == 2) { + for (int64 i0 = 0; i0 < literal.shape().dimensions(0); ++i0) { + for (int64 i1 = 0; i1 < literal.shape().dimensions(1); ++i1) { + per_cell({i0, i1}, GetAsString(literal, {i0, i1})); + } + } + return; + } + + if (ShapeUtil::Rank(literal.shape()) == 3) { + for (int64 i0 = 0; i0 < literal.shape().dimensions(0); ++i0) { + for (int64 i1 = 0; i1 < literal.shape().dimensions(1); ++i1) { + for (int64 i2 = 0; i2 < literal.shape().dimensions(2); ++i2) { + per_cell({i0, i1, i2}, GetAsString(literal, {i0, i1, i2})); + } + } + } + return; + } + + if (ShapeUtil::Rank(literal.shape()) == 4) { + for (int64 i0 = 0; i0 < literal.shape().dimensions(0); ++i0) { + for (int64 i1 = 0; i1 < literal.shape().dimensions(1); ++i1) { + for (int64 i2 = 0; i2 < literal.shape().dimensions(2); ++i2) { + for (int64 i3 = 0; i3 < literal.shape().dimensions(3); ++i3) { + per_cell({i0, i1, i2, i3}, GetAsString(literal, {i0, i1, i2, i3})); + } + } + } + } + return; + } + + LOG(FATAL) << "unhandled rank: " << ShapeUtil::Rank(literal.shape()); +} + +namespace { + +// Helper function which compares whether the elements of literal1 are equal to +// the elements of literal2. 
Recursively iterates through the entire +// multidimensional index space and compares the literal elements +// one-by-one. literal1 and literal2 must be compatible (same dimensions and +// type). +template +bool EqualElements(const Literal& literal1, const Literal& literal2, + int dimension, std::vector* multi_index) { + if (dimension == ShapeUtil::Rank(literal1.shape())) { + return (LiteralUtil::Get(literal1, *multi_index) == + LiteralUtil::Get(literal2, *multi_index)); + } + for (int64 i = 0; i < literal1.shape().dimensions(dimension); ++i) { + (*multi_index)[dimension] = i; + if (!EqualElements(literal1, literal2, dimension + 1, + multi_index)) { + return false; + } + } + return true; +} + +} // namespace + +/* static */ bool LiteralUtil::Equal(const Literal& literal1, + const Literal& literal2) { + if (!ShapeUtil::Compatible(literal1.shape(), literal2.shape())) { + return false; + } + if (ShapeUtil::IsTuple(literal1.shape())) { + // Because the shapes are compatible, they must have the same number of + // tuple elements. + CHECK_EQ(literal1.tuple_literals_size(), literal2.tuple_literals_size()); + for (int i = 0; i < literal1.tuple_literals_size(); ++i) { + if (!Equal(literal1.tuple_literals(i), literal2.tuple_literals(i))) { + return false; + } + } + return true; + } else { + std::vector multi_index(ShapeUtil::Rank(literal1.shape()), 0); + switch (literal1.shape().element_type()) { + case PRED: + return EqualElements(literal1, literal2, 0, &multi_index); + case U8: + return EqualElements(literal1, literal2, 0, &multi_index); + case S32: + return EqualElements(literal1, literal2, 0, &multi_index); + case S64: + return EqualElements(literal1, literal2, 0, &multi_index); + case U32: + return EqualElements(literal1, literal2, 0, &multi_index); + case U64: + return EqualElements(literal1, literal2, 0, &multi_index); + case F32: + return EqualElements(literal1, literal2, 0, &multi_index); + case F64: + return EqualElements(literal1, literal2, 0, &multi_index); + default: + LOG(FATAL) << "Unimplemented: LiteralUtil::Equal for type " + << PrimitiveType_Name(literal1.shape().element_type()); + } + } +} + +template <> +/* static */ tensorflow::gtl::ArraySlice LiteralUtil::GetArraySlice( + const Literal& literal) { + CHECK(literal.shape().element_type() == PRED); + return literal.preds(); +} + +template <> +/* static */ tensorflow::protobuf::RepeatedField* +LiteralUtil::GetMutableRepeatedField(Literal* literal) { + CHECK(literal->shape().element_type() == PRED); + return literal->mutable_preds(); +} + +template <> +/* static */ tensorflow::gtl::ArraySlice +LiteralUtil::GetArraySlice(const Literal& literal) { + CHECK(literal.shape().element_type() == U32); + return literal.u32s(); +} + +template <> +/* static */ tensorflow::protobuf::RepeatedField* +LiteralUtil::GetMutableRepeatedField(Literal* literal) { + CHECK(literal->shape().element_type() == U32); + return literal->mutable_u32s(); +} + +template <> +/* static */ tensorflow::gtl::ArraySlice +LiteralUtil::GetArraySlice(const Literal& literal) { + CHECK(literal.shape().element_type() == U64); + return AsUInt64Slice(literal.u64s()); +} + +template <> +/* static */ tensorflow::protobuf::RepeatedField* +LiteralUtil::GetMutableRepeatedField( + Literal* literal) { + CHECK(literal->shape().element_type() == U64); + return literal->mutable_u64s(); +} + +template <> +/* static */ tensorflow::gtl::ArraySlice +LiteralUtil::GetArraySlice(const Literal& literal) { + CHECK(literal.shape().element_type() == S32); + return literal.s32s(); +} + +template <> 
+/* static */ tensorflow::protobuf::RepeatedField* +LiteralUtil::GetMutableRepeatedField(Literal* literal) { + CHECK(literal->shape().element_type() == S32); + return literal->mutable_s32s(); +} + +template <> +/* static */ tensorflow::gtl::ArraySlice +LiteralUtil::GetArraySlice(const Literal& literal) { + CHECK(literal.shape().element_type() == S64); + return AsInt64Slice(literal.s64s()); +} + +template <> +/* static */ tensorflow::protobuf::RepeatedField* +LiteralUtil::GetMutableRepeatedField( + Literal* literal) { + CHECK(literal->shape().element_type() == S64); + return literal->mutable_s64s(); +} + +template <> +/* static */ tensorflow::gtl::ArraySlice +LiteralUtil::GetArraySlice(const Literal& literal) { + CHECK(literal.shape().element_type() == F32); + return literal.f32s(); +} + +template <> +/* static */ tensorflow::protobuf::RepeatedField* +LiteralUtil::GetMutableRepeatedField(Literal* literal) { + CHECK(literal->shape().element_type() == F32); + return literal->mutable_f32s(); +} + +template <> +/* static */ tensorflow::gtl::ArraySlice +LiteralUtil::GetArraySlice(const Literal& literal) { + CHECK(literal.shape().element_type() == F64); + return literal.f64s(); +} + +template <> +/* static */ tensorflow::protobuf::RepeatedField* +LiteralUtil::GetMutableRepeatedField(Literal* literal) { + CHECK(literal->shape().element_type() == F64); + return literal->mutable_f64s(); +} + +template +static bool AllElementsEqualValue(const Literal& literal, NativeT value) { + for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) { + auto multi_index = + IndexUtil::LinearIndexToMultidimensionalIndex(literal.shape(), i); + if (LiteralUtil::Get(literal, multi_index) != value) { + return false; + } + } + return true; +} + +/* static */ bool LiteralUtil::IsAll(const Literal& literal, int8 value) { + switch (literal.shape().element_type()) { + case U8: + if (value >= 0) { + return AllElementsEqualValue(literal, value); + } + return false; + case U32: + if (value >= 0) { + return AllElementsEqualValue(literal, value); + } + return false; + case U64: + if (value >= 0) { + return AllElementsEqualValue(literal, value); + } + return false; + case S8: + return AllElementsEqualValue(literal, value); + case S32: + return AllElementsEqualValue(literal, value); + case S64: + return AllElementsEqualValue(literal, value); + case F32: + return AllElementsEqualValue(literal, value); + case F64: + return AllElementsEqualValue(literal, value); + case PRED: + if (value == 0) { + return AllElementsEqualValue(literal, false); + } + if (value == 1) { + return AllElementsEqualValue(literal, true); + } + return false; + default: + return false; + } +} + +/* static */ bool LiteralUtil::IsZero( + const Literal& literal, tensorflow::gtl::ArraySlice indices) { + switch (literal.shape().element_type()) { + case U8: + return Get(literal, indices) == 0; + case U32: + return Get(literal, indices) == 0; + case U64: + return Get(literal, indices) == 0; + case S8: + return Get(literal, indices) == 0; + case S32: + return Get(literal, indices) == 0; + case S64: + return Get(literal, indices) == 0; + case F32: + return Get(literal, indices) == 0.0f; + case F64: + return Get(literal, indices) == 0.0; + case PRED: + return Get(literal, indices) == false; + default: + LOG(FATAL) << "Input literal must be an array."; + } +} + +template <> +/* static */ void LiteralUtil::PopulateWithValue( + int64 value, tensorflow::gtl::ArraySlice dimensions, + Literal* literal) { + *literal->mutable_shape() = ShapeUtil::MakeShape( + 
primitive_util::NativeToPrimitiveType(), dimensions); + tensorflow::protobuf::RepeatedField* + repeated_field = + GetMutableRepeatedField(literal); + for (int64 i = 0; i < ShapeUtil::ElementsIn(literal->shape()); ++i) { + repeated_field->Add(value); + } +} + +template <> +/* static */ void LiteralUtil::PopulateWithValue( + uint64 value, tensorflow::gtl::ArraySlice dimensions, + Literal* literal) { + *literal->mutable_shape() = ShapeUtil::MakeShape( + primitive_util::NativeToPrimitiveType(), dimensions); + tensorflow::protobuf::RepeatedField* + repeated_field = + GetMutableRepeatedField(literal); + for (int64 i = 0; i < ShapeUtil::ElementsIn(literal->shape()); ++i) { + repeated_field->Add(value); + } +} + +template <> +/* static */ void LiteralUtil::Resize(int64 num_elements, int64 value, + Literal* literal) { + CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); + tensorflow::protobuf::RepeatedField* + repeated_field = + GetMutableRepeatedField(literal); + repeated_field->Resize(num_elements, value); +} + +template <> +/* static */ void LiteralUtil::Resize(int64 num_elements, uint64 value, + Literal* literal) { + CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); + tensorflow::protobuf::RepeatedField* + repeated_field = + GetMutableRepeatedField(literal); + repeated_field->Resize(num_elements, value); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h new file mode 100644 index 0000000000..78e9e3fb24 --- /dev/null +++ b/tensorflow/compiler/xla/literal_util.h @@ -0,0 +1,1004 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Utilities for dealing with Literal protobufs. + +#ifndef TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_ + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/array3d.h" +#include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/index_util.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/bitmap.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// Utility class for dealing with XLA literal values. 
Most methods are +// templated by native (host) type which corresponds to a unique XLA +// PrimitiveType. See ComputationBuilder for details. Not all primitive types +// defined in xla_data.proto have a corresponding native type or even have a +// storage location in the Literal proto yet (for example, primitive type F16). +class LiteralUtil { + public: + // Create new literal of a given rank. To minimize ambiguity (for users and + // the compiler) these CreateR[0-2] methods should explicitly specify the + // native type. For example: + // + // CreateR1({1.0, 42.0}); + // CreateR2({{1, 2}, {3, 4}}); + // + // The variants not ending with WithLayout use the default XLA layout for the + // literal's linear representation in memory. + template + static std::unique_ptr CreateR0(NativeT value); + template + static std::unique_ptr CreateR1( + tensorflow::gtl::ArraySlice values); + static std::unique_ptr CreateR1( + const tensorflow::core::Bitmap& values); + template + static std::unique_ptr CreateR2( + std::initializer_list> values); + template + static std::unique_ptr CreateR2WithLayout( + std::initializer_list> values, + const Layout& layout); + template + static std::unique_ptr CreateR3( + std::initializer_list< + std::initializer_list>> + values); + template + static std::unique_ptr CreateR3WithLayout( + std::initializer_list< + std::initializer_list>> + values, + const Layout& layout); + template + static std::unique_ptr CreateR4( + std::initializer_list>>> + values); + template + static std::unique_ptr CreateR4WithLayout( + std::initializer_list>>> + values, + const Layout& layout); + + // Creates a new value that has the equivalent value as literal, but conforms + // to new_layout; e.g. a literal matrix that was in {0, 1} minor-to-major + // dimension layout can be re-layed-out as {1, 0} minor-to-major dimension + // layout and the value in the cell at any given logical index (i0, i1) will + // be the same. + // + // Note: this is useful when the client wants to ensure that a value placed in + // the XLA allocation tracker has a particular layout; for efficiency + // purposes or avoiding unimplemented operation/layout combinations. + static std::unique_ptr Relayout(const Literal& literal, + const Layout& new_layout); + + // Reshapes literal 'input' to have 'shape'. Both the original shape and + // 'shape' must contain the same number of elements. The implementation + // currently only supports monotonic dim0-major layouts. + static StatusOr> Reshape( + const xla::Literal& input, tensorflow::gtl::ArraySlice shape); + + // Creates a new literal by reordering the dimensions of the original literal. + // The given `permutation` must be a permutation of the dimension numbers + // in the original literal, and it specifies the order of the new dimensions + // in the result literal (i.e., new_order[i] = old_order[permutation[i]]). + // For example, a transpose call on a literal of shape [3 x 8 x 4] and + // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8]. + static std::unique_ptr Transpose( + const Literal& literal, tensorflow::gtl::ArraySlice permutation); + + // Creates a sub-array from the the given literal by extracting the indices + // [start_index, limit_index) of each dimension. The result literal has the + // same rank and layout as for the given literal. The number of indices in + // start_indices and limit_indices must be the rank of the literal, and the + // indices follow the order of the dimensions. 
+ static std::unique_ptr Slice( + const Literal& literal, tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices); + + // Create a literal by converting each element in an original literal to a new + // type. + template + static std::unique_ptr Convert(const Literal& literal); + + // Create a literal value zero of the given primitive type. + static Literal Zero(PrimitiveType primitive_type); + + // Create a literal value one of the given primitive type. + static Literal One(PrimitiveType primitive_type); + + // Creates a literal value containing the minimum value of the given + // primitive type. For floating-point types, returns -inf. + static Literal MinValue(PrimitiveType primitive_type); + + // Create a literal value containing the maximum value of the given + // primitive type. For floating-point types, returns inf. + static Literal MaxValue(PrimitiveType primitive_type); + + // Create a literal of the given shape where each element is `value`. + template + static std::unique_ptr CreateFullWithMonotonicDim0MajorLayout( + tensorflow::gtl::ArraySlice dimensions, NativeT value); + + // Create a new literal from an array. The variants not ending with WithLayout + // use the default XLA layout for the literal's linear representation in + // memory. + template + static std::unique_ptr CreateR2FromArray2D( + const Array2D& values); + template + static std::unique_ptr CreateR2FromArray2DWithLayout( + const Array2D& values, const Layout& layout); + template + static std::unique_ptr CreateR3FromArray3D( + const Array3D& values); + template + static std::unique_ptr CreateR3FromArray3DWithLayout( + const Array3D& values, const Layout& layout); + template + static std::unique_ptr CreateR4FromArray4D( + const Array4D& values); + template + static std::unique_ptr CreateR4FromArray4DWithLayout( + const Array4D& values, const Layout& layout); + + // Creates a new vector of U8s literal value from a string. + static std::unique_ptr CreateR1U8(tensorflow::StringPiece value); + + // Creates a linspace-populated literal with the given number of rows and + // columns. + static std::unique_ptr CreateR2F32Linspace(float from, float to, + int64 rows, int64 cols); + + // Creates a literal that projects the (x, y) dimensions given in values into + // the z dimension given by "projection". + template + static std::unique_ptr CreateR3Projected( + std::initializer_list> values, + int64 projection); + + // Creates a literal that projects the (x, y) dimensions given in values into + // the z and p dimensions given. + template + static std::unique_ptr CreateR4Projected( + std::initializer_list> values, + int64 projection_p, int64 projection_z); + + // Clones literal into an owned unique_ptr version. + static std::unique_ptr CloneToUnique(const Literal& literal); + + // Gets or sets an element in the literal at the given index. The index is + // CHECKed against the dimension sizes. + template + static NativeT Get(const Literal& literal, + tensorflow::gtl::ArraySlice multi_index); + template + static void Set(Literal* literal, + tensorflow::gtl::ArraySlice multi_index, + NativeT value); + + // Returns the element value at index (0, ..., 0), however many zeroes are + // required for that index. + template + static NativeT GetFirstElement(const Literal& literal); + + // As Get(), but determines the correct type and converts the value + // into text. 
+ static string GetAsString(const Literal& literal, + tensorflow::gtl::ArraySlice multi_index); + + // Returns an identity matrix (rank 2) with the given row and column count. + template + static std::unique_ptr MakeIdentityR2(int64 size); + + // Returns a tuple literal composed of given literals. + static std::unique_ptr MakeTuple( + tensorflow::gtl::ArraySlice elements); + + // Validates that the data payload of the literal matches the literal shape; + // if it does not, an appropriate status is returned. + static tensorflow::Status ValidateLiteral(const Literal& literal); + + // Returns a string representation of the literal value. + static string ToString(const Literal& literal); + + // Invokes the "per cell" callback for each element in the provided + // literal with the element's indices and a string representation of + // the element's value. + // + // This function is useful if you want a polymorphic representation + // of the tensor's elements (turning it to a string for something + // like representation in a protobuf). + static void EachCellAsString( + const Literal& literal, + std::function indices, + const string& value)> + per_cell); + template + static void EachCell( + const Literal& literal, + std::function indices, + NativeT value)> + per_cell); + + // Templated methods which populate the given repeated field in the Literal + // proto with the given value(s). The Shape field of the Literal proto is set + // to match the array dimensions and type. Examples: + // + // // Populate with floats. + // Array2D float_values = ... + // PopulateR2FromArray2D(values, literal); + // + // // Populate with int32s. + // PopulateR2({{1, 2}, {3, 4}}, literal); + // + template + static void PopulateR0(NativeT values, Literal* literal); + template + static void PopulateR1(tensorflow::gtl::ArraySlice values, + Literal* literal); + static void PopulateR1(const tensorflow::core::Bitmap& values, + Literal* literal); + template + static void PopulateR2( + std::initializer_list> values, + Literal* literal); + template + static void PopulateR2WithLayout( + std::initializer_list> values, + const Layout& layout, Literal* literal); + template + static void PopulateR2FromArray2D(const Array2D& values, + Literal* literal); + template + static void PopulateR2FromArray2DWithLayout(const Array2D& values, + const Layout& layout, + Literal* literal); + template + static void PopulateR3FromArray3D(const Array3D& values, + Literal* literal); + template + static void PopulateR3FromArray3DWithLayout(const Array3D& values, + const Layout& layout, + Literal* literal); + template + static void PopulateR4FromArray4D(const Array4D& values, + Literal* literal); + template + static void PopulateR4FromArray4DWithLayout(const Array4D& values, + const Layout& layout, + Literal* literal); + + // Creates a Literal of the given dimensions with all elements set to the + // given value. + template + static void PopulateWithValue(NativeT value, + tensorflow::gtl::ArraySlice dimensions, + Literal* literal); + + // Returns a pointer to the underlying buffer in the protobuf containing the + // array data. Use with care. + static const void* InternalData(const Literal& literal); + static void* MutableInternalData(Literal* literal); + + // Allocates space in the repeated_field of the literal sufficient to hold + // num_elements of the literal's primitive type. Values in the buffer are set + // to zero. num_elements must equal the number of elements in the literals + // shape. 
+  static void Reserve(int64 num_elements, Literal* literal);
+
+  // Allocates space in the repeated_field of the literal sufficient to hold
+  // num_elements of the literal's primitive type and sets each element in the
+  // literal to the given value. num_elements must equal the number of elements
+  // in the literal's shape.
+  template <typename NativeT>
+  static void Resize(int64 num_elements, NativeT value, Literal* literal);
+
+  // Returns true if the two given literals have the same shape and
+  // values. Layout is not considered in the comparison.
+  static bool Equal(const Literal& literal1, const Literal& literal2);
+
+  // Returns whether every element in the given literal is equal to value.
+  //
+  // value is an int8 because we expect this to be called with small
+  // compile-time constants (0, -1, etc.) and so that whatever value you pass
+  // can be represented exactly by floating-point types as small as 16 bits.
+  //
+  // If value doesn't fit in literal's type, returns false. Values of 1/0 are
+  // considered equal to true/false; other values are not considered equal to
+  // true.
+  static bool IsAll(const Literal& literal, int8 value);
+
+  // Returns whether the literal is zero at the specified index. The literal
+  // must be an array.
+  static bool IsZero(const Literal& literal,
+                     tensorflow::gtl::ArraySlice<int64> indices);
+
+ private:
+  // Returns an ArraySlice view of the array for the given literal for the
+  // given NativeT (e.g., float). These functions map native type to XLA
+  // PrimitiveType via template specialization. The unspecialized forms below
+  // abort to handle the error case where the given native type does not map
+  // to an XLA primitive type.
+  template <typename NativeT>
+  static tensorflow::gtl::ArraySlice<NativeT> GetArraySlice(
+      const Literal& literal) {
+    static_assert(!std::is_same<NativeT, NativeT>::value,
+                  "Cannot map native type to primitive type.");
+  }
+  template <typename NativeT>
+  static tensorflow::protobuf::RepeatedField<NativeT>* GetMutableRepeatedField(
+      Literal* literal) {
+    // Make the expression depend on the template parameter NativeT so
+    // that this compile-time error only appears if this function is
+    // instantiated with some concrete type that is not specialized
+    // below.
+    static_assert(!std::is_same<NativeT, NativeT>::value,
+                  "Cannot map native type to primitive type.");
+  }
+
+  // Returns the linear index of the given index within the literal's
+  // element_type repeated field.
+  static int64 LinearIndex(const Literal& literal,
+                           tensorflow::gtl::ArraySlice<int64> multi_index);
+
+  TF_DISALLOW_COPY_AND_ASSIGN(LiteralUtil);
+};
+
+// Declarations of template specializations for GetArraySlice and
+// GetMutableRepeatedField. The specializations map native type to XLA
+// primitive type.
+template <> +/* static */ tensorflow::gtl::ArraySlice LiteralUtil::GetArraySlice( + const Literal& literal); + +template <> +/* static */ tensorflow::protobuf::RepeatedField* +LiteralUtil::GetMutableRepeatedField(Literal* literal); + +template <> +/* static */ tensorflow::gtl::ArraySlice +LiteralUtil::GetArraySlice(const Literal& literal); + +template <> +/* static */ tensorflow::protobuf::RepeatedField* +LiteralUtil::GetMutableRepeatedField(Literal* literal); + +template <> +/* static */ tensorflow::gtl::ArraySlice +LiteralUtil::GetArraySlice(const Literal& literal); + +template <> +/* static */ tensorflow::protobuf::RepeatedField* +LiteralUtil::GetMutableRepeatedField( + Literal* literal); + +template <> +/* static */ tensorflow::gtl::ArraySlice +LiteralUtil::GetArraySlice(const Literal& literal); + +template <> +/* static */ tensorflow::protobuf::RepeatedField* +LiteralUtil::GetMutableRepeatedField(Literal* literal); + +template <> +/* static */ tensorflow::gtl::ArraySlice +LiteralUtil::GetArraySlice(const Literal& literal); + +template <> +/* static */ tensorflow::protobuf::RepeatedField* +LiteralUtil::GetMutableRepeatedField( + Literal* literal); + +template <> +/* static */ tensorflow::gtl::ArraySlice +LiteralUtil::GetArraySlice(const Literal& literal); + +template <> +/* static */ tensorflow::protobuf::RepeatedField* +LiteralUtil::GetMutableRepeatedField(Literal* literal); + +template <> +/* static */ tensorflow::gtl::ArraySlice +LiteralUtil::GetArraySlice(const Literal& literal); + +template <> +/* static */ tensorflow::protobuf::RepeatedField* +LiteralUtil::GetMutableRepeatedField(Literal* literal); + +template +/* static */ std::unique_ptr LiteralUtil::CreateR0(NativeT value) { + auto literal = MakeUnique(); + PopulateR0(value, literal.get()); + return literal; +} + +template +/* static */ std::unique_ptr LiteralUtil::CreateR1( + tensorflow::gtl::ArraySlice values) { + auto literal = MakeUnique(); + PopulateR1(values, literal.get()); + return literal; +} + +template +/* static */ std::unique_ptr LiteralUtil::CreateR2WithLayout( + std::initializer_list> values, + const Layout& layout) { + auto literal = MakeUnique(); + PopulateR2WithLayout(values, layout, literal.get()); + return literal; +} + +template +/* static */ std::unique_ptr LiteralUtil::CreateR2( + std::initializer_list> values) { + return CreateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2()); +} + +template +/* static */ std::unique_ptr LiteralUtil::CreateR3WithLayout( + std::initializer_list>> + values, + const Layout& layout) { + const int64 d0 = values.size(); + const int64 d1 = values.begin()->size(); + const int64 d2 = values.begin()->begin()->size(); + Array3D tmp(d0, d1, d2); + int64 i0 = 0; + for (auto d1_values : values) { + int64 i1 = 0; + for (auto d2_values : d1_values) { + int64 i2 = 0; + for (auto value : d2_values) { + tmp(i0, i1, i2) = value; + ++i2; + } + ++i1; + } + ++i0; + } + return CreateR3FromArray3DWithLayout(tmp, layout); +} + +template +/* static */ std::unique_ptr LiteralUtil::CreateR3( + std::initializer_list>> + values) { + return CreateR3WithLayout(values, LayoutUtil::GetDefaultLayoutForR3()); +} + +template +/* static */ std::unique_ptr LiteralUtil::CreateR4WithLayout( + std::initializer_list>>> + values, + const Layout& layout) { + const int64 d0 = values.size(); + const int64 d1 = values.begin()->size(); + const int64 d2 = values.begin()->begin()->size(); + const int64 d3 = values.begin()->begin()->begin()->size(); + Array4D tmp(d0, d1, d2, d3); + int64 i0 = 0; + for (auto 
d1_values : values) { + int64 i1 = 0; + for (auto d2_values : d1_values) { + int64 i2 = 0; + for (auto d3_values : d2_values) { + int64 i3 = 0; + for (auto value : d3_values) { + tmp(i0, i1, i2, i3) = value; + ++i3; + } + ++i2; + } + ++i1; + } + ++i0; + } + return CreateR4FromArray4DWithLayout(tmp, layout); +} + +template +/* static */ std::unique_ptr LiteralUtil::CreateR4( + std::initializer_list>>> + values) { + return CreateR4WithLayout(values, LayoutUtil::GetDefaultLayoutForR4()); +} + +template +/* static */ std::unique_ptr +LiteralUtil::CreateR2FromArray2DWithLayout(const Array2D& values, + const Layout& layout) { + auto literal = MakeUnique(); + PopulateR2FromArray2DWithLayout(values, layout, literal.get()); + return literal; +} + +template +/* static */ std::unique_ptr LiteralUtil::CreateR2FromArray2D( + const Array2D& values) { + return CreateR2FromArray2DWithLayout(values, + LayoutUtil::GetDefaultLayoutForR2()); +} +template +/* static */ std::unique_ptr +LiteralUtil::CreateR3FromArray3DWithLayout(const Array3D& values, + const Layout& layout) { + auto literal = MakeUnique(); + PopulateR3FromArray3DWithLayout(values, layout, literal.get()); + return literal; +} + +template +/* static */ std::unique_ptr LiteralUtil::CreateR3FromArray3D( + const Array3D& values) { + return CreateR3FromArray3DWithLayout(values, + LayoutUtil::GetDefaultLayoutForR3()); +} + +template +/* static */ std::unique_ptr LiteralUtil::CreateR3Projected( + std::initializer_list> values, + int64 projection) { + int64 dim0_size = projection; + int64 dim1_size = values.size(); + int64 dim2_size = values.begin()->size(); + + Array3D array(dim0_size, dim1_size, dim2_size); + for (int64 dim0 = 0; dim0 < dim0_size; ++dim0) { + int64 dim1 = 0; + for (auto inner_list : values) { + int64 dim2 = 0; + for (auto value : inner_list) { + array(dim0, dim1, dim2) = value; + ++dim2; + } + CHECK_EQ(dim2_size, dim2); + ++dim1; + } + CHECK_EQ(dim1_size, dim1); + } + return CreateR3FromArray3D(array); +} + +template +/* static */ std::unique_ptr LiteralUtil::CreateR4Projected( + std::initializer_list> values, + int64 projection_p, int64 projection_z) { + int64 dim0_size = projection_p; + int64 dim1_size = projection_z; + int64 dim2_size = values.size(); + int64 dim3_size = values.begin()->size(); + + Array4D array(dim0_size, dim1_size, dim2_size, dim3_size); + for (int64 dim0 = 0; dim0 < dim0_size; ++dim0) { + for (int64 dim1 = 0; dim1 < dim1_size; ++dim1) { + int64 dim2 = 0; + for (auto inner_list : values) { + int64 dim3 = 0; + for (auto value : inner_list) { + array(dim0, dim1, dim2, dim3) = value; + ++dim3; + } + CHECK_EQ(dim3_size, dim3); + ++dim2; + } + CHECK_EQ(dim2_size, dim2); + } + } + return CreateR4FromArray4D(array); +} + +template +/* static */ std::unique_ptr LiteralUtil::CreateR4FromArray4D( + const Array4D& values) { + return CreateR4FromArray4DWithLayout(values, + LayoutUtil::GetDefaultLayoutForR4()); +} + +template +/* static */ std::unique_ptr +LiteralUtil::CreateR4FromArray4DWithLayout(const Array4D& values, + const Layout& layout) { + auto literal = MakeUnique(); + PopulateR4FromArray4DWithLayout(values, layout, literal.get()); + return literal; +} + +template +/* static */ NativeT LiteralUtil::Get( + const Literal& literal, tensorflow::gtl::ArraySlice multi_index) { + int64 linear_index = LinearIndex(literal, multi_index); + return GetArraySlice(literal).at(linear_index); +} + +template +/* static */ NativeT LiteralUtil::GetFirstElement(const Literal& literal) { + return GetArraySlice(literal).at(0); +} + 
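Illustrative sketch only, not part of the original change: a minimal example of how the element accessors above fit together, assuming just the CreateR2/Get/GetFirstElement/Set declarations from this header.

    // Build a 2x2 int32 literal, then read and write individual cells.
    std::unique_ptr<Literal> m = LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}});
    int32 top_left = LiteralUtil::Get<int32>(*m, {0, 0});     // 1
    int32 first = LiteralUtil::GetFirstElement<int32>(*m);    // element (0, 0), also 1
    LiteralUtil::Set<int32>(m.get(), {1, 1}, 42);             // m is now {{1, 2}, {3, 42}}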
+template <>
+/* static */ inline uint8 LiteralUtil::Get<uint8>(
+    const Literal& literal, tensorflow::gtl::ArraySlice<int64> multi_index) {
+  CHECK(literal.shape().element_type() == U8);
+  int64 linear_index = LinearIndex(literal, multi_index);
+  return literal.u8s()[linear_index];
+}
+
+template <>
+/* static */ inline int8 LiteralUtil::Get<int8>(
+    const Literal& literal, tensorflow::gtl::ArraySlice<int64> multi_index) {
+  CHECK(literal.shape().element_type() == S8);
+  int64 linear_index = LinearIndex(literal, multi_index);
+  return literal.u8s()[linear_index];
+}
+
+template <typename NativeT>
+/* static */ void LiteralUtil::Set(
+    Literal* literal, tensorflow::gtl::ArraySlice<int64> multi_index,
+    NativeT value) {
+  int64 linear_index = LinearIndex(*literal, multi_index);
+  GetMutableRepeatedField<NativeT>(literal)->Set(linear_index, value);
+}
+
+template <>
+/* static */ inline void LiteralUtil::Set<uint8>(
+    Literal* literal, tensorflow::gtl::ArraySlice<int64> multi_index,
+    uint8 value) {
+  int64 linear_index = LinearIndex(*literal, multi_index);
+  (*literal->mutable_u8s())[linear_index] = value;
+}
+
+template <>
+/* static */ inline void LiteralUtil::Set<int8>(
+    Literal* literal, tensorflow::gtl::ArraySlice<int64> multi_index,
+    int8 value) {
+  return Set<uint8>(literal, multi_index, value);
+}
+
+template <>
+/* static */ inline void LiteralUtil::Set<int64>(
+    Literal* literal, tensorflow::gtl::ArraySlice<int64> multi_index,
+    int64 value) {
+  int64 linear_index = LinearIndex(*literal, multi_index);
+  (*literal->mutable_s64s())[linear_index] = value;
+}
+
+template <>
+/* static */ inline void LiteralUtil::Set<uint64>(
+    Literal* literal, tensorflow::gtl::ArraySlice<int64> multi_index,
+    uint64 value) {
+  int64 linear_index = LinearIndex(*literal, multi_index);
+  (*literal->mutable_u64s())[linear_index] = value;
+}
+
+// Returns an identity matrix (rank 2) with the given row and column count.
+template +/* static */ std::unique_ptr LiteralUtil::MakeIdentityR2(int64 size) { + Array2D array(size, size, 0); + for (int64 i = 0; i < size; ++i) { + array(i, i) = 1; + } + return CreateR2FromArray2D(array); +} + +template +/* static */ void LiteralUtil::EachCell( + const Literal& literal, + std::function indices, + NativeT value)> + per_cell) { + if (ShapeUtil::HasZeroElements(literal.shape())) { + return; + } + std::vector indices(ShapeUtil::Rank(literal.shape()), 0); + do { + per_cell(indices, Get(literal, indices)); + } while (IndexUtil::BumpIndices(literal.shape(), &indices)); +} + +template +/* static */ void LiteralUtil::PopulateR0(NativeT value, Literal* literal) { + *literal->mutable_shape() = ShapeUtil::MakeShape( + primitive_util::NativeToPrimitiveType(), {}); + tensorflow::protobuf::RepeatedField* repeated_field = + GetMutableRepeatedField(literal); + repeated_field->Add(value); +} + +template <> +/* static */ inline void LiteralUtil::PopulateR0(uint8 value, + Literal* literal) { + *literal->mutable_shape() = + ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), {}); + literal->mutable_u8s()->push_back(value); +} + +template <> +/* static */ inline void LiteralUtil::PopulateR0(int8 value, + Literal* literal) { + *literal->mutable_shape() = + ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), {}); + literal->mutable_u8s()->push_back(value); +} + +template <> +/* static */ inline void LiteralUtil::PopulateR0(uint64 value, + Literal* literal) { + *literal->mutable_shape() = + ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), {}); + literal->mutable_u64s()->Add(value); +} + +template <> +/* static */ inline void LiteralUtil::PopulateR0(int64 value, + Literal* literal) { + *literal->mutable_shape() = + ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), {}); + literal->mutable_s64s()->Add(value); +} + +template +/* static */ void LiteralUtil::PopulateR1( + tensorflow::gtl::ArraySlice values, Literal* literal) { + *literal->mutable_shape() = + ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), + {static_cast(values.size())}); + Reserve(values.size(), literal); + for (int64 i = 0; i < values.size(); ++i) { + Set(literal, {i}, values[i]); + } +} + +/* static */ inline void LiteralUtil::PopulateR1( + const tensorflow::core::Bitmap& values, Literal* literal) { + *literal->mutable_shape() = + ShapeUtil::MakeShape(PRED, {static_cast(values.bits())}); + Reserve(values.bits(), literal); + for (int64 i = 0; i < values.bits(); ++i) { + Set(literal, {i}, values.get(i)); + } +} + +template +/* static */ void LiteralUtil::PopulateR2WithLayout( + std::initializer_list> values, + const Layout& layout, Literal* literal) { + *literal->mutable_shape() = ShapeUtil::MakeShapeWithLayout( + primitive_util::NativeToPrimitiveType(), + {static_cast(values.size()), + static_cast(values.begin()->size())}, + AsInt64Slice(layout.minor_to_major())); + + const int64 dim0_size = values.size(); + const int64 dim1_size = values.begin()->size(); + CHECK_EQ(dim0_size, literal->shape().dimensions(0)); + CHECK_EQ(dim1_size, literal->shape().dimensions(1)); + + const int64 num_elements = dim1_size * dim0_size; + Reserve(num_elements, literal); + + int64 dim0 = 0; + for (auto inner_list : values) { + int64 dim1 = 0; + for (auto value : inner_list) { + Set(literal, {dim0, dim1}, value); + ++dim1; + } + CHECK_EQ(dim1_size, dim1); + ++dim0; + } +} + +template +/* static */ void LiteralUtil::PopulateR2( + std::initializer_list> values, + Literal* literal) { + 
PopulateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2(), literal); +} + +template +/* static */ void LiteralUtil::PopulateR2FromArray2DWithLayout( + const Array2D& values, const Layout& layout, Literal* literal) { + *literal->mutable_shape() = ShapeUtil::MakeShapeWithLayout( + primitive_util::NativeToPrimitiveType(), + {values.height(), values.width()}, AsInt64Slice(layout.minor_to_major())); + + const int64 dim1_size = values.width(); + const int64 dim0_size = values.height(); + CHECK_EQ(dim0_size, literal->shape().dimensions(0)); + CHECK_EQ(dim1_size, literal->shape().dimensions(1)); + Reserve(dim1_size * dim0_size, literal); + for (int64 dim0 = 0; dim0 < dim0_size; ++dim0) { + for (int64 dim1 = 0; dim1 < dim1_size; ++dim1) { + Set(literal, {dim0, dim1}, values(dim0, dim1)); + } + } +} + +template +/* static */ void LiteralUtil::PopulateR2FromArray2D( + const Array2D& values, Literal* literal) { + PopulateR2FromArray2DWithLayout(values, LayoutUtil::GetDefaultLayoutForR2(), + literal); +} +template +/* static */ void LiteralUtil::PopulateR3FromArray3DWithLayout( + const Array3D& values, const Layout& layout, Literal* literal) { + *literal->mutable_shape() = ShapeUtil::MakeShapeWithLayout( + primitive_util::NativeToPrimitiveType(), + {values.n1(), values.n2(), values.n3()}, + AsInt64Slice(layout.minor_to_major())); + + CHECK_EQ(values.n1(), literal->shape().dimensions(0)); + CHECK_EQ(values.n2(), literal->shape().dimensions(1)); + CHECK_EQ(values.n3(), literal->shape().dimensions(2)); + Reserve(values.n1() * values.n2() * values.n3(), literal); + for (int64 dim0 = 0; dim0 < values.n1(); ++dim0) { + for (int64 dim1 = 0; dim1 < values.n2(); ++dim1) { + for (int64 dim2 = 0; dim2 < values.n3(); ++dim2) { + Set(literal, {dim0, dim1, dim2}, values(dim0, dim1, dim2)); + } + } + } +} + +template +/* static */ void LiteralUtil::PopulateR3FromArray3D( + const Array3D& values, Literal* literal) { + PopulateR3FromArray3DWithLayout(values, LayoutUtil::GetDefaultLayoutForR3(), + literal); +} + +template +/* static */ void LiteralUtil::PopulateR4FromArray4DWithLayout( + const Array4D& values, const Layout& layout, Literal* literal) { + *literal->mutable_shape() = ShapeUtil::MakeShapeWithLayout( + primitive_util::NativeToPrimitiveType(), + {values.planes(), values.depth(), values.height(), values.width()}, + AsInt64Slice(layout.minor_to_major())); + + CHECK_EQ(values.n1(), literal->shape().dimensions(0)); + CHECK_EQ(values.n2(), literal->shape().dimensions(1)); + CHECK_EQ(values.n3(), literal->shape().dimensions(2)); + CHECK_EQ(values.n4(), literal->shape().dimensions(3)); + Reserve(values.n1() * values.n2() * values.n3() * values.n4(), literal); + for (int64 dim0 = 0; dim0 < values.n1(); ++dim0) { + for (int64 dim1 = 0; dim1 < values.n2(); ++dim1) { + for (int64 dim2 = 0; dim2 < values.n3(); ++dim2) { + for (int64 dim3 = 0; dim3 < values.n4(); ++dim3) { + Set(literal, {dim0, dim1, dim2, dim3}, + values(dim0, dim1, dim2, dim3)); + } + } + } + } +} + +template +/* static */ void LiteralUtil::PopulateR4FromArray4D( + const Array4D& values, Literal* literal) { + PopulateR4FromArray4DWithLayout(values, LayoutUtil::GetDefaultLayoutForR4(), + literal); +} + +template +/* static */ void LiteralUtil::PopulateWithValue( + NativeT value, tensorflow::gtl::ArraySlice dimensions, + Literal* literal) { + *literal->mutable_shape() = ShapeUtil::MakeShape( + primitive_util::NativeToPrimitiveType(), dimensions); + tensorflow::protobuf::RepeatedField* repeated_field = + GetMutableRepeatedField(literal); + for 
(int64 i = 0; i < ShapeUtil::ElementsIn(literal->shape()); ++i) { + repeated_field->Add(value); + } +} + +template <> +/* static */ void LiteralUtil::PopulateWithValue( + int64 value, tensorflow::gtl::ArraySlice dimensions, + Literal* literal); + +template <> +/* static */ void LiteralUtil::PopulateWithValue( + uint64 value, tensorflow::gtl::ArraySlice dimensions, + Literal* literal); + +template +/* static */ std::unique_ptr LiteralUtil::Convert( + const Literal& literal) { + auto result_literal = MakeUnique(); + Shape result_shape = literal.shape(); + result_shape.set_element_type( + primitive_util::NativeToPrimitiveType()); + *result_literal->mutable_shape() = result_shape; + LiteralUtil::Reserve(ShapeUtil::ElementsIn(result_shape), + result_literal.get()); + LiteralUtil::EachCell( + literal, + [&](tensorflow::gtl::ArraySlice indices, NativeSrcT value) { + LiteralUtil::Set(result_literal.get(), indices, + static_cast(value)); + }); + return result_literal; +} + +template +/* static */ void LiteralUtil::Resize(int64 num_elements, NativeT value, + Literal* literal) { + CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); + tensorflow::protobuf::RepeatedField* repeated_field = + GetMutableRepeatedField(literal); + repeated_field->Resize(num_elements, value); +} + +template <> +/* static */ void LiteralUtil::Resize(int64 num_elements, int64 value, + Literal* literal); + +template <> +/* static */ void LiteralUtil::Resize(int64 num_elements, uint64 value, + Literal* literal); + +template +/* static */ std::unique_ptr +LiteralUtil::CreateFullWithMonotonicDim0MajorLayout( + tensorflow::gtl::ArraySlice dimensions, NativeT value) { + Shape shape = ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( + primitive_util::NativeToPrimitiveType(), dimensions); + auto literal = MakeUnique(); + *literal->mutable_shape() = shape; + Reserve(ShapeUtil::ElementsIn(shape), literal.get()); + std::vector index(dimensions.size(), 0); + do { + Set(literal.get(), index, value); + } while (IndexUtil::BumpIndices(shape, &index)); + return literal; +} + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_ diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc new file mode 100644 index 0000000000..09410d5c33 --- /dev/null +++ b/tensorflow/compiler/xla/literal_util_test.cc @@ -0,0 +1,622 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/literal_util.h" + +#include + +#include "tensorflow/compiler/xla/array3d.h" +#include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class LiteralUtilTest : public ::testing::Test { + protected: + LiteralUtilTest() { + Array4D arr4d({ + // clang-format off + { // i0=0 + { // i1=0 + {1, 2, 3}, // i2=0 + {4, 5, 6}, // i2=1 + {7, 8, 9}, // i2=2 + }, + { // i1=1 + {11, 12, 13}, + {14, 15, 16}, + {17, 18, 19}, + }, + }, + { // i0=1 + { // i1=0 + {101, 102, 103}, + {104, 105, 106}, + {107, 108, 109}, + }, + { // i1=1 + {201, 202, 203}, // i2=0 + {204, 205, 206}, // i2=1 + {207, 208, 209}, // i2=2 + }, + }, + // clang-format on + }); + + layout_r2_dim0major_ = LayoutUtil::MakeLayout({1, 0}); + layout_r2_dim0minor_ = LayoutUtil::MakeLayout({0, 1}); + layout_r3_dim0major_ = LayoutUtil::MakeLayout({2, 1, 0}); + layout_r3_dim0minor_ = LayoutUtil::MakeLayout({0, 1, 2}); + layout_r4_dim0major_ = LayoutUtil::MakeLayout({3, 2, 1, 0}); + layout_r4_dim0minor_ = LayoutUtil::MakeLayout({0, 1, 2, 3}); + + literal_r4_2x2x3x3_dim0major_ = + LiteralUtil::CreateR4FromArray4DWithLayout(arr4d, + layout_r4_dim0major_); + literal_r4_2x2x3x3_dim0minor_ = + LiteralUtil::CreateR4FromArray4DWithLayout(arr4d, + layout_r4_dim0minor_); + } + + Layout layout_r2_dim0major_; + Layout layout_r2_dim0minor_; + Layout layout_r3_dim0major_; + Layout layout_r3_dim0minor_; + Layout layout_r4_dim0major_; + Layout layout_r4_dim0minor_; + std::unique_ptr literal_r4_2x2x3x3_dim0major_; + std::unique_ptr literal_r4_2x2x3x3_dim0minor_; +}; + +TEST_F(LiteralUtilTest, LiteralScalarToString) { + auto true_lit = LiteralUtil::CreateR0(true); + ASSERT_EQ("true", LiteralUtil::ToString(*true_lit)); + + auto false_lit = LiteralUtil::CreateR0(false); + ASSERT_EQ("false", LiteralUtil::ToString(*false_lit)); + + auto u32_lit = LiteralUtil::CreateR0(42); + ASSERT_EQ("42", LiteralUtil::ToString(*u32_lit)); + + auto s32_lit = LiteralUtil::CreateR0(-999); + ASSERT_EQ("-999", LiteralUtil::ToString(*s32_lit)); + + auto f32_lit = LiteralUtil::CreateR0(3.14f); + ASSERT_EQ("3.14", LiteralUtil::ToString(*f32_lit)); +} + +TEST_F(LiteralUtilTest, LiteralVectorToString) { + auto pred_vec = LiteralUtil::CreateR1({true, false, true}); + ASSERT_EQ("{101}", LiteralUtil::ToString(*pred_vec)); +} + +TEST_F(LiteralUtilTest, R2ToString) { + const auto literal = LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}); + const string expected = R"(s32[3,2] { + { 1, 2 }, + { 3, 4 }, + { 5, 6 }, +})"; + ASSERT_EQ(expected, LiteralUtil::ToString(*literal)); +} + +TEST_F(LiteralUtilTest, R3ToString) { + const auto literal = + LiteralUtil::CreateR3({{{1}, {2}}, {{3}, {4}}, {{5}, {6}}}); + const string expected = R"(s32[3,2,1] { +{ { 1 }, + { 2 } }, +{ { 3 }, + { 4 } }, +{ { 5 }, + { 6 } } +})"; + ASSERT_EQ(expected, LiteralUtil::ToString(*literal)); +} + +TEST_F(LiteralUtilTest, TupleToString) { + auto scalar = LiteralUtil::CreateR0(1.0); + auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); + const string expected = R"((f32[], f32[2,2]) ( +1, +f32[2,2] { + { 1, 2 }, + { 3, 4 }, +}, +))"; + 
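+  // A tuple stringifies as its tuple shape followed by each element literal in order.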
ASSERT_EQ(expected, LiteralUtil::ToString(*tuple)); +} + +TEST_F(LiteralUtilTest, CreateR3FromArray3d) { + // clang-format off + Array3D array_3d({ + {{1.0f, 2.0f}, + {3.0f, 4.0f}, + {5.0f, 6.0f}}, + {{7.0f, 8.0f}, + {9.0f, 10.0f}, + {11.0f, 12.0f}}, + }); + // clang-format on + + auto literal = LiteralUtil::CreateR3FromArray3D(array_3d); + EXPECT_MATCH(testing::PBToVec( + literal->shape().dimensions()), + testing::VectorMatcher({2, 3, 2})); + string result = LiteralUtil::ToString(*literal); + const string expected = R"(f32[2,3,2] { +{ { 1, 2 }, + { 3, 4 }, + { 5, 6 } }, +{ { 7, 8 }, + { 9, 10 }, + { 11, 12 } } +})"; + ASSERT_EQ(expected, result); +} + +TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) { + // clang-format off + auto literal = LiteralUtil::CreateR4Projected({ + {1, 2}, + {1001, 1002}, + {2001, 2002}, + }, /*projection_p=*/1, /*projection_z=*/2); + // clang-format on + EXPECT_MATCH( + testing::PBToVec(literal->shape().dimensions()), + testing::VectorMatcher({1, 2, 3, 2})); + string result = LiteralUtil::ToString(*literal); + const string expected = R"(f32[1,2,3,2] { + { // i0=0 + { // i1=0 + {1, 2}, + {1001, 1002}, + {2001, 2002}, + }, + { // i1=1 + {1, 2}, + {1001, 1002}, + {2001, 2002}, + }, + }, +})"; + ASSERT_EQ(expected, result); +} + +TEST_F(LiteralUtilTest, LiteralR4F32Stringifies) { + EXPECT_MATCH( + testing::PBToVec( + literal_r4_2x2x3x3_dim0major_->shape().dimensions()), + testing::VectorMatcher({2, 2, 3, 3})); + string result = LiteralUtil::ToString(*literal_r4_2x2x3x3_dim0major_); + const string expected = R"(f32[2,2,3,3] { + { // i0=0 + { // i1=0 + {1, 2, 3}, + {4, 5, 6}, + {7, 8, 9}, + }, + { // i1=1 + {11, 12, 13}, + {14, 15, 16}, + {17, 18, 19}, + }, + }, + { // i0=1 + { // i1=0 + {101, 102, 103}, + {104, 105, 106}, + {107, 108, 109}, + }, + { // i1=1 + {201, 202, 203}, + {204, 205, 206}, + {207, 208, 209}, + }, + }, +})"; + ASSERT_EQ(expected, result); +} + +TEST_F(LiteralUtilTest, EachCellR2F32) { + // clang-format off + auto literal = LiteralUtil::CreateR2({ + {3.1f, 4.2f}, + {9.3f, 12.4f}, + }); + // clang-format on + std::vector> seen; + LiteralUtil::EachCellAsString( + *literal, + [&seen](tensorflow::gtl::ArraySlice indices, const string& value) { + seen.emplace_back(indices[0], indices[1], value); + }); + + using Elem = std::tuple; + std::vector expected = {Elem(0, 0, "3.1"), Elem(0, 1, "4.2"), + Elem(1, 0, "9.3"), Elem(1, 1, "12.4")}; + EXPECT_EQ(expected, seen); +} + +TEST_F(LiteralUtilTest, ScalarEquality) { + // Test LiteralUtil::Equal with scalars. + auto f32_42 = LiteralUtil::CreateR0(42.0); + auto f32_42_clone = LiteralUtil::CreateR0(42.0); + + EXPECT_TRUE(LiteralUtil::Equal(*f32_42, *f32_42)); + EXPECT_TRUE(LiteralUtil::Equal(*f32_42, *f32_42_clone)); + + auto f32_123 = LiteralUtil::CreateR0(123.0); + EXPECT_FALSE(LiteralUtil::Equal(*f32_42, *f32_123)); + + auto f64_42 = LiteralUtil::CreateR0(42.0); + EXPECT_FALSE(LiteralUtil::Equal(*f32_42, *f64_42)); +} + +TEST_F(LiteralUtilTest, NonScalarEquality) { + // Test LiteralUtil::Equal with nonscalars. 
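+  // Equal requires identical shapes (rank included), not just matching element values.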
+ auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto matrix_clone = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto matrix_different = + LiteralUtil::CreateR2({{4.0, 3.0}, {1.0, 2.0}}); + auto vector_literal = LiteralUtil::CreateR1({1.0, 2.0, 3.0, 4.0}); + auto scalar = LiteralUtil::CreateR0(1.0); + + EXPECT_TRUE(LiteralUtil::Equal(*matrix, *matrix)); + EXPECT_TRUE(LiteralUtil::Equal(*matrix, *matrix_clone)); + EXPECT_FALSE(LiteralUtil::Equal(*matrix, *matrix_different)); + EXPECT_FALSE(LiteralUtil::Equal(*matrix, *vector_literal)); + EXPECT_FALSE(LiteralUtil::Equal(*matrix, *scalar)); +} + +TEST_F(LiteralUtilTest, DifferentLayoutEquality) { + // Test LiteralUtil::Equal with literals which have different layouts. + auto colmajor = MakeUnique(); + *colmajor->mutable_shape() = ShapeUtil::MakeShape(F32, {2, 2}); + *colmajor->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); + LiteralUtil::Reserve(4, colmajor.get()); + LiteralUtil::Set(colmajor.get(), {0, 0}, 1.0); + LiteralUtil::Set(colmajor.get(), {0, 1}, 2.0); + LiteralUtil::Set(colmajor.get(), {1, 0}, 3.0); + LiteralUtil::Set(colmajor.get(), {1, 1}, 4.0); + + auto rowmajor = MakeUnique(); + *rowmajor->mutable_shape() = ShapeUtil::MakeShape(F32, {2, 2}); + *rowmajor->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 0}); + LiteralUtil::Reserve(4, rowmajor.get()); + LiteralUtil::Set(rowmajor.get(), {0, 0}, 1.0); + LiteralUtil::Set(rowmajor.get(), {0, 1}, 2.0); + LiteralUtil::Set(rowmajor.get(), {1, 0}, 3.0); + LiteralUtil::Set(rowmajor.get(), {1, 1}, 4.0); + + EXPECT_TRUE(LiteralUtil::Equal(*rowmajor, *colmajor)); +} + +TEST_F(LiteralUtilTest, TupleEquality) { + // Test LiteralUtil::Equal with tuples. + auto scalar = LiteralUtil::CreateR0(1.0); + auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto tuple1 = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); + + // Tuple with the same elements. One element is shared with the original + // tuple, the other is a clone of the element in the original tuple. + auto scalar_clone = LiteralUtil::CreateR0(1.0); + auto tuple2 = LiteralUtil::MakeTuple({scalar_clone.get(), matrix.get()}); + EXPECT_TRUE(LiteralUtil::Equal(*tuple1, *tuple2)); + + // Tuple with elements reversed. + auto reversed_tuple = LiteralUtil::MakeTuple({matrix.get(), scalar.get()}); + EXPECT_FALSE(LiteralUtil::Equal(*tuple1, *reversed_tuple)); + + // Tuple with different value. + auto scalar_42 = LiteralUtil::CreateR0(42.0); + auto different_tuple = + LiteralUtil::MakeTuple({scalar_42.get(), matrix.get()}); + EXPECT_FALSE(LiteralUtil::Equal(*tuple1, *different_tuple)); +} + +TEST_F(LiteralUtilTest, IsAllTuple) { + auto element1 = LiteralUtil::CreateR0(0.0); + auto element2 = LiteralUtil::CreateR2({{0.0, 0.0}, {0.0, 0.0}}); + auto tuple = LiteralUtil::MakeTuple({element1.get(), element1.get()}); + + // Tuples should always return false for IsAll. 
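+  // This holds even though every leaf element of the tuple below is zero.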
+ EXPECT_FALSE(LiteralUtil::IsAll(*tuple, 0)); + EXPECT_FALSE(LiteralUtil::IsAll(*tuple, 1)); +} + +TEST_F(LiteralUtilTest, IsAll) { + EXPECT_TRUE(LiteralUtil::IsAll(*LiteralUtil::CreateR0(false), 0)); + EXPECT_TRUE(LiteralUtil::IsAll(*LiteralUtil::CreateR0(true), 1)); + EXPECT_FALSE(LiteralUtil::IsAll(*LiteralUtil::CreateR0(false), 1)); + EXPECT_FALSE(LiteralUtil::IsAll(*LiteralUtil::CreateR0(false), 2)); + EXPECT_FALSE(LiteralUtil::IsAll(*LiteralUtil::CreateR0(true), 0)); + EXPECT_FALSE(LiteralUtil::IsAll(*LiteralUtil::CreateR0(true), 2)); + EXPECT_FALSE(LiteralUtil::IsAll(*LiteralUtil::CreateR0(true), -1)); + + // We shouldn't reinterpret int8_min as an unsigned type and then decide that + // it is equal to 255. + auto int8_min = std::numeric_limits::min(); + EXPECT_FALSE( + LiteralUtil::IsAll(*LiteralUtil::CreateR0(255), int8_min)); + + EXPECT_TRUE(LiteralUtil::IsAll(*LiteralUtil::CreateR0(42.0), 42)); + EXPECT_FALSE(LiteralUtil::IsAll(*LiteralUtil::CreateR0(42.0001), 42)); + + EXPECT_TRUE( + LiteralUtil::IsAll(*LiteralUtil::CreateR1({100, 100, 100}), 100)); + EXPECT_FALSE(LiteralUtil::IsAll( + *LiteralUtil::CreateR1({100, 100, 100.001}), 100)); + + EXPECT_TRUE( + LiteralUtil::IsAll(*LiteralUtil::CreateR2({{8, 8}, {8, 8}}), 8)); + EXPECT_FALSE( + LiteralUtil::IsAll(*LiteralUtil::CreateR2({{8, 8}, {8, 9}}), 8)); + EXPECT_FALSE( + LiteralUtil::IsAll(*LiteralUtil::CreateR2({{9, 8}, {8, 8}}), 8)); + + auto uint64_max = std::numeric_limits::max(); + EXPECT_FALSE(LiteralUtil::IsAll( + *LiteralUtil::CreateR2( + {{uint64_max, uint64_max}, {uint64_max, uint64_max}}), + -1)); +} + +TEST_F(LiteralUtilTest, IsZero) { + auto scalar_zero = LiteralUtil::CreateR0(0.0f); + auto scalar_one = LiteralUtil::CreateR0(1.0f); + EXPECT_TRUE(LiteralUtil::IsZero(*scalar_zero, {})); + EXPECT_FALSE(LiteralUtil::IsZero(*scalar_one, {})); + + auto array = LiteralUtil::CreateR2({{1, 2, 0, 3}, {1, 0, 1, 2}}); + EXPECT_FALSE(LiteralUtil::IsZero(*array, {0, 1})); + EXPECT_TRUE(LiteralUtil::IsZero(*array, {0, 2})); + EXPECT_TRUE(LiteralUtil::IsZero(*array, {1, 1})); + EXPECT_FALSE(LiteralUtil::IsZero(*array, {1, 2})); +} + +template +class LiteralUtilTestTemplated : public ::testing::Test {}; + +using TestedTypes = ::testing::Types; +TYPED_TEST_CASE(LiteralUtilTestTemplated, TestedTypes); + +TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) { + // Make a non-integer for floating point types. 
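+  // (For integral TypeParams the division below truncates to zero, which this test tolerates.)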
+ TypeParam half = TypeParam(1) / TypeParam(2); + auto data = LiteralUtil::CreateR2({{half, 2}, {3, 4}}); + const Layout layout01 = LayoutUtil::MakeLayout({0, 1}); + const Layout layout10 = LayoutUtil::MakeLayout({1, 0}); + + auto data01 = LiteralUtil::Relayout(*data, layout01); + EXPECT_TRUE(LayoutUtil::Equal(data01->shape().layout(), layout01)); + EXPECT_TRUE(LiteralUtil::Equal(*data, *data01)); + + auto data10 = LiteralUtil::Relayout(*data, layout10); + EXPECT_TRUE(LayoutUtil::Equal(data10->shape().layout(), layout10)); + EXPECT_TRUE(LiteralUtil::Equal(*data, *data10)); +} + +TEST_F(LiteralUtilTest, ReshapeR0) { + auto original = LiteralUtil::CreateR0(1.7f); + auto reshape = + LiteralUtil::Reshape(*original, /*shape=*/{}).ConsumeValueOrDie(); + EXPECT_TRUE(LiteralUtil::Equal(*original, *reshape)); +} + +TEST_F(LiteralUtilTest, ReshapeR4) { + // clang-format off + // F32[1x3x2x4] + auto original = LiteralUtil::CreateR4WithLayout({{ + {{10, 11, 12, 13}, {14, 15, 16, 17}}, + {{18, 19, 20, 21}, {22, 23, 24, 25}}, + {{26, 27, 28, 29}, {30, 31, 32, 33}}, + }}, layout_r4_dim0major_); + // F32[1x3x4x2] + auto expected = LiteralUtil::CreateR3WithLayout({ + {{10, 11}, {12, 13}, {14, 15}, {16, 17}}, + {{18, 19}, {20, 21}, {22, 23}, {24, 25}}, + {{26, 27}, {28, 29}, {30, 31}, {32, 33}}, + }, layout_r3_dim0major_); + // clang-format on + auto reshape = LiteralUtil::Reshape(*original, {3, 4, 2}).ConsumeValueOrDie(); + + EXPECT_TRUE(LiteralUtil::Equal(*expected, *reshape)); +} + +TEST_F(LiteralUtilTest, TransposeR0) { + auto original = LiteralUtil::CreateR0(1.7f); + auto reshape = LiteralUtil::Transpose(*original, /*permutation=*/{}); + EXPECT_TRUE(LiteralUtil::Equal(*original, *reshape)); +} + +TEST_F(LiteralUtilTest, TransposeR4) { + // clang-format off + // F32[1x3x2x4] + auto original = LiteralUtil::CreateR4({{ + {{10, 11, 12, 13}, {14, 15, 16, 17}}, + {{18, 19, 20, 21}, {22, 23, 24, 25}}, + {{26, 27, 28, 29}, {30, 31, 32, 33}}, + }}); + // clang-format on + auto reshape = + LiteralUtil::Transpose(*original, /*permutation=*/{2, 3, 0, 1}); + + LiteralUtil::EachCell( + *reshape, [&](tensorflow::gtl::ArraySlice indices, float value) { + EXPECT_EQ(value, + LiteralUtil::Get(*original, {indices[2], indices[3], + indices[0], indices[1]})); + }); +} + +TEST_F(LiteralUtilTest, TestR4RelayoutEquivalence) { + // Tests that using Relayout on an array is equivalent to creating it in the + // target layout in the first place. + auto dim0minor_relaid_to_dim0major = LiteralUtil::Relayout( + *literal_r4_2x2x3x3_dim0minor_, layout_r4_dim0major_); + EXPECT_TRUE(LiteralUtil::Equal(*literal_r4_2x2x3x3_dim0major_, + *dim0minor_relaid_to_dim0major)); + + auto dim0major_relaid_to_dim0minor = LiteralUtil::Relayout( + *literal_r4_2x2x3x3_dim0major_, layout_r4_dim0minor_); + EXPECT_TRUE(LiteralUtil::Equal(*literal_r4_2x2x3x3_dim0minor_, + *dim0major_relaid_to_dim0minor)); +} + +TEST_F(LiteralUtilTest, TestR2LinearLayout) { + // Test expected memory layout of R2 dim0-minor (column-major) literal. + auto mat_dim0minor = LiteralUtil::CreateR2WithLayout( + {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0minor_); + EXPECT_EQ(mat_dim0minor->s32s_size(), 6); + EXPECT_MATCH(testing::PBToVec(mat_dim0minor->s32s()), + testing::VectorMatcher({1, 4, 2, 5, 3, 6})); + + // Test expected memory layout when using Relayout to row major. 
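+  // Relayout must physically reorder the flat s32s buffer, not just rewrite the layout field.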
+ auto relaid_mat_to_dim0major = + LiteralUtil::Relayout(*mat_dim0minor, layout_r2_dim0major_); + EXPECT_MATCH(testing::PBToVec(relaid_mat_to_dim0major->s32s()), + testing::VectorMatcher({1, 2, 3, 4, 5, 6})); + + // Test expected memory layout of R2 created with dim0-major (row-major). + auto mat_dim0major = LiteralUtil::CreateR2WithLayout( + {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0major_); + EXPECT_EQ(mat_dim0major->s32s_size(), 6); + EXPECT_MATCH(testing::PBToVec(mat_dim0major->s32s()), + testing::VectorMatcher({1, 2, 3, 4, 5, 6})); + + // Test expected memory layout when using Relayout to column major. + auto relaid_mat_to_dim0minor = + LiteralUtil::Relayout(*mat_dim0major, layout_r2_dim0minor_); + EXPECT_MATCH(testing::PBToVec(relaid_mat_to_dim0minor->s32s()), + testing::VectorMatcher({1, 4, 2, 5, 3, 6})); +} + +TEST_F(LiteralUtilTest, TestR3LinearLayout) { + // Test expected memory layout of R3 dim0-minor (column-major) literal. + Array3D arr3d( + // clang-format off + { + { + {1, 2, 3}, + {4, 5, 6}, + }, + { + {7, 8, 9}, + {10, 11, 12}, + }, + }); // clang-format on + auto lit_dim0minor = LiteralUtil::CreateR3FromArray3DWithLayout( + arr3d, layout_r3_dim0minor_); + + EXPECT_EQ(lit_dim0minor->s32s_size(), 12); + std::vector expected_dim0minor{1, 7, 4, 10, 2, 8, 5, 11, 3, 9, 6, 12}; + EXPECT_MATCH(testing::PBToVec(lit_dim0minor->s32s()), + testing::VectorMatcher(expected_dim0minor)); + + // Test expected memory layout when using Relayout to row major. + auto relaid_lit_to_dim0major = + LiteralUtil::Relayout(*lit_dim0minor, layout_r3_dim0major_); + std::vector expected_dim0major{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + EXPECT_MATCH(testing::PBToVec(relaid_lit_to_dim0major->s32s()), + testing::VectorMatcher(expected_dim0major)); + + // Test expected memory layout of R3 created with dim0-major (row-major). + auto lit_dim0major = LiteralUtil::CreateR3FromArray3DWithLayout( + arr3d, layout_r3_dim0major_); + EXPECT_EQ(lit_dim0major->s32s_size(), 12); + EXPECT_MATCH(testing::PBToVec(lit_dim0major->s32s()), + testing::VectorMatcher(expected_dim0major)); + + // Test expected memory layout when using Relayout to column major. 
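+  // Relaying the dim0-major literal out to dim0-minor should match the flat order of the
+  // literal created directly in that layout.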
+ auto relaid_lit_to_dim0minor = + LiteralUtil::Relayout(*lit_dim0major, layout_r3_dim0minor_); + EXPECT_MATCH(testing::PBToVec(relaid_lit_to_dim0minor->s32s()), + testing::VectorMatcher(expected_dim0minor)); +} + +TEST_F(LiteralUtilTest, SliceR0S32) { + auto input = LiteralUtil::CreateR0(1); + auto result = LiteralUtil::Slice(*input, {}, {}); + EXPECT_TRUE(LiteralUtil::Equal(*input, *result)); +} + +TEST_F(LiteralUtilTest, SliceR1F32) { + auto input = LiteralUtil::CreateR1({1.0, 2.0, 3.0, 4.0, 5.0}); + auto result = LiteralUtil::Slice(*input, {3}, {4}); + auto expected = LiteralUtil::CreateR1({4.0}); + EXPECT_TRUE(LiteralUtil::Equal(*expected, *result)); +} + +TEST_F(LiteralUtilTest, SliceR2U32) { + auto input_3x4 = LiteralUtil::CreateR2( + {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}); + auto result = LiteralUtil::Slice(*input_3x4, {0, 2}, {2, 4}); + auto expected = LiteralUtil::CreateR2({{3, 4}, {7, 8}}); + EXPECT_TRUE(LiteralUtil::Equal(*expected, *result)); +} + +TEST_F(LiteralUtilTest, SliceR3U32Full) { + auto input_2x3x2 = LiteralUtil::CreateR3( + {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}); + auto result = LiteralUtil::Slice(*input_2x3x2, {0, 0, 0}, {2, 3, 2}); + EXPECT_TRUE(LiteralUtil::Equal(*input_2x3x2, *result)); +} + +TEST_F(LiteralUtilTest, PopulateR1S64) { + Literal output; + LiteralUtil::PopulateR1({77}, &output); + auto expected = LiteralUtil::CreateR1({77}); + EXPECT_TRUE(LiteralUtil::Equal(output, *expected)); +} + +TEST_F(LiteralUtilTest, PopulateR2U64) { + Literal output; + LiteralUtil::PopulateR1({{77, 88}}, &output); + auto expected = LiteralUtil::CreateR1({{77, 88}}); + EXPECT_TRUE(LiteralUtil::Equal(output, *expected)); +} + +TEST_F(LiteralUtilTest, PopulateWithValueR0F32) { + Literal output; + LiteralUtil::PopulateWithValue(2.5f, {}, &output); + auto expected = LiteralUtil::CreateR0(2.5f); + EXPECT_TRUE(LiteralUtil::Equal(output, *expected)); +} + +TEST_F(LiteralUtilTest, PopulateWithValueR1S64) { + Literal output; + LiteralUtil::PopulateWithValue(-7, {3}, &output); + auto expected = LiteralUtil::CreateR1({-7, -7, -7}); + EXPECT_TRUE(LiteralUtil::Equal(output, *expected)); +} + +TEST_F(LiteralUtilTest, PopulateWithValueR2U64) { + Literal output; + LiteralUtil::PopulateWithValue(42, {2, 2}, &output); + auto expected = LiteralUtil::CreateR2({{42, 42}, {42, 42}}); + EXPECT_TRUE(LiteralUtil::Equal(output, *expected)); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/map_util.h b/tensorflow/compiler/xla/map_util.h new file mode 100644 index 0000000000..51d0d5f86f --- /dev/null +++ b/tensorflow/compiler/xla/map_util.h @@ -0,0 +1,65 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_MAP_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_MAP_UTIL_H_ + +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +// FindOrDie returns a const reference to the value associated with +// the given key if it exists. Crashes otherwise. +// +// This is intended as a replacement for operator[] as an rvalue (for reading) +// when the key is guaranteed to exist. +template +const typename Collection::value_type::second_type& FindOrDie( + const Collection& collection, + const typename Collection::value_type::first_type& key) { + typename Collection::const_iterator it = collection.find(key); + CHECK(it != collection.end()) << "Map key not found: " << key; + return it->second; +} + +// Same as above, but returns a non-const reference. +template +typename Collection::value_type::second_type& FindOrDie( + Collection& collection, // NOLINT + const typename Collection::value_type::first_type& key) { + typename Collection::iterator it = collection.find(key); + CHECK(it != collection.end()) << "Map key not found: " << key; + return it->second; +} + +// Inserts the key-value pair into the collection. Dies if key was already +// present. +template +void InsertOrDie(Collection* const collection, + const typename Collection::value_type::first_type& key, + const typename Collection::value_type::second_type& data) { + auto p = collection->insert(std::make_pair(key, data)); + CHECK(p.second) << "duplicate key: " << key; +} + +// Returns true if and only if the given collection contains the given key. +template +bool ContainsKey(const Collection& collection, const Key& key) { + return collection.find(key) != collection.end(); +} + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_MAP_UTIL_H_ diff --git a/tensorflow/compiler/xla/packed_literal_reader.cc b/tensorflow/compiler/xla/packed_literal_reader.cc new file mode 100644 index 0000000000..21766a2a0c --- /dev/null +++ b/tensorflow/compiler/xla/packed_literal_reader.cc @@ -0,0 +1,92 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/packed_literal_reader.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/casts.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +PackedLiteralReader::PackedLiteralReader(tensorflow::RandomAccessFile* file) + : file_(file), offset_(0) {} + +PackedLiteralReader::~PackedLiteralReader() { delete file_; } + +StatusOr> PackedLiteralReader::Read( + const Shape& shape, const Layout* layout) { + VLOG(3) << "reading shape from file: " << ShapeUtil::HumanString(shape) + << " layout: " + << (layout == nullptr ? "" : layout->ShortDebugString()); + auto result = MakeUnique(); + *result->mutable_shape() = shape; + if (layout != nullptr) { + TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape(*layout, shape)); + *result->mutable_shape()->mutable_layout() = *layout; + } + + if (shape.element_type() != F32) { + return Unimplemented( + "not yet implemented element type for packed literal reading: %s", + PrimitiveType_Name(shape.element_type()).c_str()); + } + + int64 elements = ShapeUtil::ElementsIn(shape); + LiteralUtil::Resize(elements, std::numeric_limits::quiet_NaN(), + result.get()); + tensorflow::protobuf::RepeatedField* field = result->mutable_f32s(); + char* data = tensorflow::bit_cast(field->mutable_data()); + uint64 bytes = elements * sizeof(float); + tensorflow::StringPiece sp; + auto s = file_->Read(offset_, bytes, &sp, data); + offset_ += sp.size(); + if (!s.ok()) { + return s; + } else { + // Success: make sure we move the data into the right place if the Read + // call decided to return data in somewhere other than "data". + CHECK_EQ(sp.size(), bytes); + if (sp.data() != data) { + memcpy(data, sp.data(), sp.size()); + } + } + VLOG(3) << "read shape from file: " << ShapeUtil::HumanString(shape); + return std::move(result); +} + +bool PackedLiteralReader::IsExhausted() const { + // Try to read a single byte from offset_. If we can't, we've + // exhausted the data. + char single_byte[1]; + tensorflow::StringPiece sp; + auto s = file_->Read(offset_, sizeof(single_byte), &sp, single_byte); + return !s.ok(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/packed_literal_reader.h b/tensorflow/compiler/xla/packed_literal_reader.h new file mode 100644 index 0000000000..563d978cf5 --- /dev/null +++ b/tensorflow/compiler/xla/packed_literal_reader.h @@ -0,0 +1,59 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PACKED_LITERAL_READER_H_ +#define TENSORFLOW_COMPILER_XLA_PACKED_LITERAL_READER_H_ + +#include + +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/macros.h" + +namespace xla { + +// Reads packed data from a metadata-less file as requested by a user (who must +// know its internal format). These are yielded as (structured) literal values. +class PackedLiteralReader { + public: + // Ownership of file is passed to this instance -- this instance takes + // responsibility for closing it. + explicit PackedLiteralReader(tensorflow::RandomAccessFile* file); + ~PackedLiteralReader(); + + // Yields the next packed literal with shape "shape" as read from the + // underlying file stream. + // + // Layout is optional. If it is not provided, no layout is set on the literal + // that is produced. + StatusOr> Read(const Shape& shape, + const Layout* layout = nullptr); + + // Returns whether the input file has been fully exhausted; i.e. all available + // packed literals have been read and we're at the end of the file. + bool IsExhausted() const; + + private: + tensorflow::RandomAccessFile* file_; // We own and close in our destructor + uint64 offset_; // Next file offset to read from + + TF_DISALLOW_COPY_AND_ASSIGN(PackedLiteralReader); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_PACKED_LITERAL_READER_H_ diff --git a/tensorflow/compiler/xla/port/BUILD b/tensorflow/compiler/xla/port/BUILD new file mode 100644 index 0000000000..6fc5f1185c --- /dev/null +++ b/tensorflow/compiler/xla/port/BUILD @@ -0,0 +1,33 @@ +licenses(["notice"]) # Apache 2.0 + +# Filegroup used to collect source files for dependency checking. +filegroup( + name = "c_srcs", + data = glob([ + "**/*.cc", + "**/*.h", + ]), + visibility = ["//tensorflow/compiler/xla:internal"], +) + +cc_library( + name = "initialize", + hdrs = ["initialize.h"], + visibility = [ + "//tensorflow/compiler/xla:__subpackages__", + ], +) + +# ----------------------------------------------------------------------------- + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/compiler/xla/port/initialize.h b/tensorflow/compiler/xla/port/initialize.h new file mode 100644 index 0000000000..13d9632f97 --- /dev/null +++ b/tensorflow/compiler/xla/port/initialize.h @@ -0,0 +1,39 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PORT_INITIALIZE_H_ +#define TENSORFLOW_COMPILER_XLA_PORT_INITIALIZE_H_ + +#undef REGISTER_MODULE_INITIALIZER + +namespace xla { + +class Initializer { + public: + typedef void (*InitializerFunc)(); + explicit Initializer(InitializerFunc func) { func(); } +}; + +} // namespace xla + +#define REGISTER_INITIALIZER(type, name, body) \ + static void google_init_##type##_##name() { body; } \ + xla::Initializer google_initializer_##type##_##name( \ + google_init_##type##_##name) + +#define REGISTER_MODULE_INITIALIZER(name, body) \ + REGISTER_INITIALIZER(module, name, body) + +#endif // TENSORFLOW_COMPILER_XLA_PORT_INITIALIZE_H_ diff --git a/tensorflow/compiler/xla/primitive_util.cc b/tensorflow/compiler/xla/primitive_util.cc new file mode 100644 index 0000000000..e3909ae8e9 --- /dev/null +++ b/tensorflow/compiler/xla/primitive_util.cc @@ -0,0 +1,133 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/primitive_util.h" + +#include "tensorflow/core/platform/logging.h" + +namespace xla { +namespace primitive_util { + +template <> +PrimitiveType NativeToPrimitiveType() { + return PRED; +} + +// Unsigned integer +template <> +PrimitiveType NativeToPrimitiveType() { + return U8; +} + +template <> +PrimitiveType NativeToPrimitiveType() { + return U16; +} + +template <> +PrimitiveType NativeToPrimitiveType() { + return U32; +} + +template <> +PrimitiveType NativeToPrimitiveType() { + return U64; +} + +// Signed integer +template <> +PrimitiveType NativeToPrimitiveType() { + return S8; +} + +template <> +PrimitiveType NativeToPrimitiveType() { + return S16; +} + +template <> +PrimitiveType NativeToPrimitiveType() { + return S32; +} + +template <> +PrimitiveType NativeToPrimitiveType() { + return S64; +} + +// Floating point +template <> +PrimitiveType NativeToPrimitiveType() { + return F32; +} + +template <> +PrimitiveType NativeToPrimitiveType() { + return F64; +} + +bool IsFloatingPointType(PrimitiveType type) { + return type == F16 || type == F32 || type == F64; +} + +bool IsSignedIntegralType(PrimitiveType type) { + return type == S8 || type == S16 || type == S32 || type == S64; +} + +bool IsUnsignedIntegralType(PrimitiveType type) { + return type == U8 || type == U16 || type == U32 || type == U64; +} + +bool IsIntegralType(PrimitiveType type) { + return IsUnsignedIntegralType(type) || IsSignedIntegralType(type); +} + +int BitWidth(PrimitiveType type) { + switch (type) { + case PRED: + return 1; + + case S8: + case U8: + return 8; + + case S16: + case U16: + case F16: + return 16; + + case U32: + case S32: + case F32: + return 32; + + case U64: + case S64: + case F64: + return 64; + + case TUPLE: + LOG(FATAL) << "TUPLE is an invalid type for BitWidth"; + + case OPAQUE: + LOG(FATAL) << "OPAQUE is an invalid type for BitWidth"; + + 
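+    // Any other value (e.g. a primitive type added to the enum later) has no defined bit width.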
default: + LOG(FATAL) << "Unhandled primitive type " << type; + } +} + +} // namespace primitive_util +} // namespace xla diff --git a/tensorflow/compiler/xla/primitive_util.h b/tensorflow/compiler/xla/primitive_util.h new file mode 100644 index 0000000000..78f0ee6f59 --- /dev/null +++ b/tensorflow/compiler/xla/primitive_util.h @@ -0,0 +1,157 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Utilities for dealing with XLA primitive types. + +#ifndef TENSORFLOW_COMPILER_XLA_PRIMITIVE_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_PRIMITIVE_UTIL_H_ + +#include + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace primitive_util { + +// Returns the XLA primitive type (eg, F32) corresponding to the given +// template parameter native type (eg, float). +template +PrimitiveType NativeToPrimitiveType() { + // Make the expression depend on the template parameter NativeT so + // that this compile-time error only apperas if this function is + // instantiated with some concrete type that is not specialized + // below. + static_assert(!std::is_same::value, + "Cannot map native type to primitive type."); + return PRIMITIVE_TYPE_INVALID; +} + +// Declarations of specializations for each native type which correspond to a +// XLA primitive type. +template <> +PrimitiveType NativeToPrimitiveType(); + +// Unsigned integer +template <> +PrimitiveType NativeToPrimitiveType(); + +template <> +PrimitiveType NativeToPrimitiveType(); + +template <> +PrimitiveType NativeToPrimitiveType(); + +template <> +PrimitiveType NativeToPrimitiveType(); + +// Signed integer +template <> +PrimitiveType NativeToPrimitiveType(); + +template <> +PrimitiveType NativeToPrimitiveType(); + +template <> +PrimitiveType NativeToPrimitiveType(); + +template <> +PrimitiveType NativeToPrimitiveType(); + +// Floating point +template <> +PrimitiveType NativeToPrimitiveType(); +template <> +PrimitiveType NativeToPrimitiveType(); + +bool IsFloatingPointType(PrimitiveType type); + +bool IsSignedIntegralType(PrimitiveType type); + +bool IsUnsignedIntegralType(PrimitiveType type); + +bool IsIntegralType(PrimitiveType type); + +// Returns the number of bits in the representation for a given type. +int BitWidth(PrimitiveType type); + +// Returns the native type (eg, float) corresponding to the given template +// parameter XLA primitive type (eg, F32). +template +struct PrimitiveTypeToNative; + +// Declarations of specializations for each native type which correspond to a +// XLA primitive type. 
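+// Each specialization below defines a single member alias named "type".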
+template <> +struct PrimitiveTypeToNative { + using type = bool; +}; + +// Unsigned integer +template <> +struct PrimitiveTypeToNative { + using type = uint8; +}; + +template <> +struct PrimitiveTypeToNative { + using type = uint16; +}; + +template <> +struct PrimitiveTypeToNative { + using type = uint32; +}; + +template <> +struct PrimitiveTypeToNative { + using type = uint64; +}; + +// Signed integer +template <> +struct PrimitiveTypeToNative { + using type = int8; +}; + +template <> +struct PrimitiveTypeToNative { + using type = int16; +}; + +template <> +struct PrimitiveTypeToNative { + using type = int32; +}; + +template <> +struct PrimitiveTypeToNative { + using type = int64; +}; + +// Floating point +template <> +struct PrimitiveTypeToNative { + using type = float; +}; +template <> +struct PrimitiveTypeToNative { + using type = double; +}; + +} // namespace primitive_util +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_PRIMITIVE_UTIL_H_ diff --git a/tensorflow/compiler/xla/protobuf_util.cc b/tensorflow/compiler/xla/protobuf_util.cc new file mode 100644 index 0000000000..adb2e99ad2 --- /dev/null +++ b/tensorflow/compiler/xla/protobuf_util.cc @@ -0,0 +1,35 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/protobuf_util.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { +namespace protobuf_util { + +bool ProtobufEquals(const tensorflow::protobuf::Message& m1, + const tensorflow::protobuf::Message& m2) { + // This is a bit fast and loose, but avoids introducing a dependency on + // the much more complex protobuf::util::MessageDifferencer class. For + // our purposes we just say that two protobufs are equal if their serialized + // representations are equal. + string serialized1, serialized2; + m1.AppendToString(&serialized1); + m2.AppendToString(&serialized2); + return (serialized1 == serialized2); +} + +} // namespace protobuf_util +} // namespace xla diff --git a/tensorflow/compiler/xla/protobuf_util.h b/tensorflow/compiler/xla/protobuf_util.h new file mode 100644 index 0000000000..36247f1bde --- /dev/null +++ b/tensorflow/compiler/xla/protobuf_util.h @@ -0,0 +1,35 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PROTOBUF_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_PROTOBUF_UTIL_H_ + +#include "tensorflow/core/platform/protobuf.h" + +namespace xla { +namespace protobuf_util { + +// Returns true if m1 is equal to m2. +// +// WARNING: We use protocol buffer serialization and then check for +// equality of the serialized representation, which may miss some +// cases of equality. However, for the purposes of the XLA code +// base, this form of equality checking is sufficient. +extern bool ProtobufEquals(const tensorflow::protobuf::Message& m1, + const tensorflow::protobuf::Message& m2); +} // namespace protobuf_util +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_PROTOBUF_UTIL_H_ diff --git a/tensorflow/compiler/xla/ptr_util.h b/tensorflow/compiler/xla/ptr_util.h new file mode 100644 index 0000000000..fa67030313 --- /dev/null +++ b/tensorflow/compiler/xla/ptr_util.h @@ -0,0 +1,80 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PTR_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_PTR_UTIL_H_ + +// Utility functions for pointers. + +#include + +#include +#include +#include + +namespace xla { + +namespace internal { + +// Trait to select overloads and return types for MakeUnique. +template +struct MakeUniqueResult { + using scalar = std::unique_ptr; +}; +template +struct MakeUniqueResult { + using array = std::unique_ptr; +}; +template +struct MakeUniqueResult { + using invalid = void; +}; + +} // namespace internal + +// Transfers ownership of a raw pointer to a std::unique_ptr of deduced type. +// Example: +// X* NewX(int, int); +// auto x = WrapUnique(NewX(1, 2)); // 'x' is std::unique_ptr. +// +// WrapUnique is useful for capturing the output of a raw pointer factory. +// However, prefer 'MakeUnique(args...) over 'WrapUnique(new T(args...))'. +// auto x = WrapUnique(new X(1, 2)); // works, but nonideal. +// auto x = MakeUnique(1, 2); // safer, standard, avoids raw 'new'. +// +// Note: Cannot wrap pointers to array of unknown bound (i.e. U(*)[]). +template +std::unique_ptr WrapUnique(T* ptr) { + static_assert(!std::is_array::value || std::extent::value != 0, + "types T[0] or T[] are unsupported"); + return std::unique_ptr(ptr); +} + +template +typename internal::MakeUniqueResult::scalar MakeUnique(Args&&... args) { + return std::unique_ptr(new T(std::forward(args)...)); +} + +// Overload for array of unknown bound. +// The allocation of arrays needs to use the array form of new, +// and cannot take element constructor arguments. 
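+// The trailing "()" in the new-expression value-initializes the elements, so arithmetic
+// element types start out zeroed.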
+template +typename internal::MakeUniqueResult::array MakeUnique(size_t n) { + return std::unique_ptr(new typename std::remove_extent::type[n]()); +} + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_PTR_UTIL_H_ diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc new file mode 100644 index 0000000000..f03b158fa7 --- /dev/null +++ b/tensorflow/compiler/xla/reference_util.cc @@ -0,0 +1,540 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/reference_util.h" + +#include + +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" +#include "tensorflow/compiler/xla/window_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/math/math_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +/* static */ std::unique_ptr> ReferenceUtil::TransposeArray2D( + const Array2D& operand) { + auto result = MakeUnique>(operand.width(), operand.height()); + for (int64 w = 0; w < operand.width(); ++w) { + for (int64 h = 0; h < operand.height(); ++h) { + (*result)(w, h) = operand(h, w); + } + } + + return result; +} + +/* static */ std::unique_ptr> ReferenceUtil::MatmulArray2D( + const Array2D& lhs, const Array2D& rhs) { + CHECK_EQ(lhs.width(), rhs.height()); + int m = lhs.height(); + int n = rhs.width(); + int k = lhs.width(); + auto result = MakeUnique>(m, n); + // Because Eigen is a header-oriented library, make sure that the Eigen code + // is the same as the code used by the CPU backend (otherwise the linker will + // randomly pick *some* definition). + __xla_cpu_runtime_EigenSingleThreadedMatMulF32( + /*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n, m, + k, + /*transpose_lhs=*/0, + /*transpose_rhs=*/0); + return result; +} + +/* static */ std::unique_ptr> ReferenceUtil::MatmulArray2D( + const Array2D& lhs, const Array2D& rhs) { + CHECK_EQ(lhs.width(), rhs.height()); + int m = lhs.height(); + int n = rhs.width(); + int k = lhs.width(); + auto result = MakeUnique>(m, n); + // Because Eigen is a header-oriented library, make sure that the Eigen code + // is the same as the code used by the CPU backend (otherwise the linker will + // randomly pick *some* definition). 
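+  // As in the F32 overload above, lhs/rhs are passed swapped and the sizes as (n, m, k).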
+ __xla_cpu_runtime_EigenSingleThreadedMatMulF64( + /*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n, m, + k, + /*transpose_lhs=*/0, + /*transpose_rhs=*/0); + return result; +} + +/* static */ std::unique_ptr> ReferenceUtil::Array2DF32ToF64( + const Array2D& input) { + auto result = MakeUnique>(input.height(), input.width()); + for (int64 rowno = 0; rowno < input.height(); ++rowno) { + for (int64 colno = 0; colno < input.height(); ++colno) { + (*result)(rowno, colno) = input(rowno, colno); + } + } + return result; +} + +/* static */ std::unique_ptr> ReferenceUtil::ConvArray4D( + const Array4D& lhs, const Array4D& rhs, + std::pair kernel_stride, Padding padding) { + return ConvArray4DGeneralDimensions( + lhs, rhs, kernel_stride, padding, + ComputationBuilder::CreateDefaultConvDimensionNumbers()); +} + +/* static */ int64 ReferenceUtil::WindowCount(int64 unpadded_width, + int64 window_len, int64 stride, + Padding padding) { + if (padding == Padding::kValid) { + return window_util::StridedBound(unpadded_width, window_len, stride); + } + return tensorflow::MathUtil::CeilOfRatio(unpadded_width, stride); +} + +/* static */ std::unique_ptr> ReferenceUtil::ReduceWindow4DAdd( + const Array4D& operand, float init, + const tensorflow::gtl::ArraySlice& window, + const tensorflow::gtl::ArraySlice& stride, Padding padding) { + std::vector dim_lengths{operand.n1(), operand.n2(), operand.n3(), + operand.n4()}; + auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding); + + std::vector window_counts(window.size(), 0); + std::vector pad_low(window.size(), 0); + for (int64 i = 0; i < window.size(); ++i) { + window_counts[i] = + WindowCount(dim_lengths[i], window[i], stride[i], padding); + pad_low[i] = padding_both[i].first; + } + auto result = MakeUnique>(window_counts[0], window_counts[1], + window_counts[2], window_counts[3]); + // Do a full 4D reduce window. + for (int64 i0 = 0; i0 < window_counts[0]; ++i0) { + for (int64 i1 = 0; i1 < window_counts[1]; ++i1) { + for (int64 i2 = 0; i2 < window_counts[2]; ++i2) { + for (int64 i3 = 0; i3 < window_counts[3]; ++i3) { + int64 i0_base = i0 * stride[0] - pad_low[0]; + int64 i1_base = i1 * stride[1] - pad_low[1]; + int64 i2_base = i2 * stride[2] - pad_low[2]; + int64 i3_base = i3 * stride[3] - pad_low[3]; + + float val = init; + for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) { + for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) { + for (int64 i2_win = 0; i2_win < window[2]; ++i2_win) { + for (int64 i3_win = 0; i3_win < window[3]; ++i3_win) { + if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 && + i2_base + i2_win >= 0 && i3_base + i3_win >= 0 && + i0_base + i0_win < operand.n1() && + i1_base + i1_win < operand.n2() && + i2_base + i2_win < operand.n3() && + i3_base + i3_win < operand.n4()) { + val += operand(i0_base + i0_win, i1_base + i1_win, + i2_base + i2_win, i3_base + i3_win); + } + } + } + } + } + (*result)(i0, i1, i2, i3) = val; + } + } + } + } + return result; +} + +/* static */ std::unique_ptr> +ReferenceUtil::SelectAndScatter4DGePlus( + const Array4D& operand, const Array4D& source, float init, + const tensorflow::gtl::ArraySlice& window, + const tensorflow::gtl::ArraySlice& stride, bool same_padding) { + Padding padding = same_padding ? 
Padding::kSame : Padding::kValid; + auto result = MakeUnique>(operand.n1(), operand.n2(), + operand.n3(), operand.n4()); + std::vector dim_lengths{operand.n1(), operand.n2(), operand.n3(), + operand.n4()}; + auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding); + // Fill the output, with the initial value. + result->Fill(init); + + std::vector window_counts(window.size(), 0); + std::vector pad_low(window.size(), 0); + for (int64 i = 0; i < window.size(); ++i) { + window_counts[i] = + WindowCount(dim_lengths[i], window[i], stride[i], padding); + pad_low[i] = padding_both[i].first; + } + CHECK_EQ(window_counts[0], source.n1()); + CHECK_EQ(window_counts[1], source.n2()); + CHECK_EQ(window_counts[2], source.n3()); + CHECK_EQ(window_counts[3], source.n4()); + + // Do a full 4D select and Scatter. + for (int64 i0 = 0; i0 < window_counts[0]; ++i0) { + for (int64 i1 = 0; i1 < window_counts[1]; ++i1) { + for (int64 i2 = 0; i2 < window_counts[2]; ++i2) { + for (int64 i3 = 0; i3 < window_counts[3]; ++i3) { + // Now we are inside a window and need to find the max and the argmax. + int64 i0_base = i0 * stride[0] - pad_low[0]; + int64 i1_base = i1 * stride[1] - pad_low[1]; + int64 i2_base = i2 * stride[2] - pad_low[2]; + int64 i3_base = i3 * stride[3] - pad_low[3]; + int64 scatter_0 = (i0_base >= 0) ? i0_base : 0; + int64 scatter_1 = (i1_base >= 0) ? i1_base : 0; + int64 scatter_2 = (i2_base >= 0) ? i2_base : 0; + int64 scatter_3 = (i3_base >= 0) ? i3_base : 0; + float val = operand(scatter_0, scatter_1, scatter_2, scatter_3); + for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) { + for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) { + for (int64 i2_win = 0; i2_win < window[2]; ++i2_win) { + for (int64 i3_win = 0; i3_win < window[3]; ++i3_win) { + if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 && + i2_base + i2_win >= 0 && i3_base + i3_win >= 0 && + i0_base + i0_win < operand.n1() && + i1_base + i1_win < operand.n2() && + i2_base + i2_win < operand.n3() && + i3_base + i3_win < operand.n4()) { + float tmp = operand(i0_base + i0_win, i1_base + i1_win, + i2_base + i2_win, i3_base + i3_win); + if (tmp >= val) { + val = tmp; + scatter_0 = i0_base + i0_win; + scatter_1 = i1_base + i1_win; + scatter_2 = i2_base + i2_win; + scatter_3 = i3_base + i3_win; + } + } + } + } + } + } + (*result)(scatter_0, scatter_1, scatter_2, scatter_3) += + source(i0, i1, i2, i3); + } + } + } + } + return result; +} + +/* static */ std::unique_ptr> +ReferenceUtil::ConvArray4DGeneralDimensions( + const Array4D& lhs, const Array4D& rhs, + std::pair kernel_stride, Padding padding, + ConvolutionDimensionNumbers dimension_numbers) { + return ConvArray4DGeneralDimensionsDilated(lhs, rhs, kernel_stride, padding, + {1, 1}, {1, 1}, dimension_numbers); +} + +/* static */ std::unique_ptr> +ReferenceUtil::ConvArray4DGeneralDimensionsDilated( + const Array4D& lhs, const Array4D& rhs, + std::pair kernel_stride, Padding padding, + std::pair lhs_dilation, std::pair rhs_dilation, + ConvolutionDimensionNumbers dnums) { + std::array lhs_dimensions{{lhs.n1(), lhs.n2(), lhs.n3(), lhs.n4()}}; + std::array rhs_dimensions{{rhs.n1(), rhs.n2(), rhs.n3(), rhs.n4()}}; + + const int64 ksy = kernel_stride.first; + const int64 ksx = kernel_stride.second; + const int64 dy = lhs_dilation.first; + const int64 dx = lhs_dilation.second; + const int64 dky = rhs_dilation.first; + const int64 dkx = rhs_dilation.second; + CHECK_GE(dky, 1); + CHECK_GE(dkx, 1); + CHECK_GE(dy, 1); + CHECK_GE(dx, 1); + + // Get all dimension sizes in lhs and 
rhs based on the given convolution + // dimension configuration. + const int64 ix = window_util::DilatedBound( + lhs_dimensions[dnums.spatial_dimensions(1)], dx); + const int64 iy = window_util::DilatedBound( + lhs_dimensions[dnums.spatial_dimensions(0)], dy); + const int64 iz = lhs_dimensions[dnums.feature_dimension()]; + const int64 samples = lhs_dimensions[dnums.batch_dimension()]; + const int64 kx = window_util::DilatedBound( + rhs_dimensions[dnums.kernel_spatial_dimensions(1)], dkx); + const int64 ky = window_util::DilatedBound( + rhs_dimensions[dnums.kernel_spatial_dimensions(0)], dky); + const int64 oz = rhs_dimensions[dnums.kernel_output_feature_dimension()]; + { + const int64 kiz = rhs_dimensions[dnums.kernel_input_feature_dimension()]; + CHECK_EQ(kiz, iz); + } + + if (padding == Padding::kSame) { + // We reject same padding with kernel striding, since it's somewhat + // nonsensical. We can always follow up to implement this with the desired + // semantics if anybody actually uses it. + CHECK_EQ(1, ksy); + CHECK_EQ(1, ksx); + } + + const int64 ox = + padding == Padding::kSame ? ix : window_util::StridedBound(ix, kx, ksx); + const int64 oy = + padding == Padding::kSame ? iy : window_util::StridedBound(iy, ky, ksy); + const int64 istartx = + padding == Padding::kValid ? 0 : kx % 2 == 0 ? -(kx / 2 - 1) : -kx / 2; + const int64 istarty = + padding == Padding::kValid ? 0 : ky % 2 == 0 ? -(ky / 2 - 1) : -ky / 2; + // Create the output result array and reset the values to 0. + std::array result_dimensions; + result_dimensions[dnums.batch_dimension()] = samples; + result_dimensions[dnums.feature_dimension()] = oz; + result_dimensions[dnums.spatial_dimensions(0)] = oy; + result_dimensions[dnums.spatial_dimensions(1)] = ox; + auto result = + MakeUnique>(result_dimensions[0], result_dimensions[1], + result_dimensions[2], result_dimensions[3]); + result->Fill(0.0); + + // Lambda to access the lhs operand at the given 4D index. + const auto lhs_element = [&](int64 batch, int64 feature, int64 height, + int64 width) { + if (height % dy != 0 || width % dx != 0) { + return 0.0f; + } + + std::array index; + index[dnums.batch_dimension()] = batch; + index[dnums.feature_dimension()] = feature; + index[dnums.spatial_dimensions(0)] = height / dy; + index[dnums.spatial_dimensions(1)] = width / dx; + return lhs(index[0], index[1], index[2], index[3]); + }; + + // Lambda to access the rhs operand at the given 4D index. + const auto rhs_element = [&](int64 kernel_output_feature, + int64 kernel_input_feature, int64 height, + int64 width) { + CHECK_EQ(height % dky, 0); + CHECK_EQ(width % dkx, 0); + std::array index; + index[dnums.kernel_output_feature_dimension()] = kernel_output_feature; + index[dnums.kernel_input_feature_dimension()] = kernel_input_feature; + index[dnums.kernel_spatial_dimensions(0)] = height / dky; + index[dnums.kernel_spatial_dimensions(1)] = width / dkx; + return rhs(index[0], index[1], index[2], index[3]); + }; + + // Lambda to access the result data at the given 4D index. 
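+  // The lambda returns a float& so the accumulation loop below can add into the output in place.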
+ const auto result_element = [&](int64 batch, int64 kernel_output_feature, + int64 height, int64 width) -> float& { + std::array index; + index[dnums.batch_dimension()] = batch; + index[dnums.feature_dimension()] = kernel_output_feature; + index[dnums.spatial_dimensions(0)] = height; + index[dnums.spatial_dimensions(1)] = width; + return (*result)(index[0], index[1], index[2], index[3]); + }; + + for (int64 oyi = 0; oyi < oy; ++oyi) { + for (int64 oxi = 0; oxi < ox; ++oxi) { + for (int64 sample = 0; sample < samples; ++sample) { + for (int64 izi = 0; izi < iz; ++izi) { + for (int64 ozi = 0; ozi < oz; ++ozi) { + for (int64 kyi = 0; kyi < ky; kyi += dky) { + for (int64 kxi = 0; kxi < kx; kxi += dkx) { + int64 iyi = istarty + ksy * oyi + kyi; + int64 ixi = istartx + ksx * oxi + kxi; + float input = (iyi >= iy || ixi >= ix || iyi < 0 || ixi < 0) + ? 0.0 + : lhs_element(sample, izi, iyi, ixi); + float gain = rhs_element(ozi, izi, kyi, kxi); + float addend = input * gain; + result_element(sample, ozi, oyi, oxi) += addend; + } + } + } + } + } + } + } + return result; +} + +/* static */ std::unique_ptr> +ReferenceUtil::ReduceToColArray2D( + const Array2D& matrix, float init, + std::function reduce_function) { + int64 rows = matrix.height(); + int64 cols = matrix.width(); + auto result = MakeUnique>(); + for (int64 i = 0; i < rows; ++i) { + float acc = init; + for (int64 j = 0; j < cols; ++j) { + acc = reduce_function(acc, matrix(i, j)); + } + result->push_back(acc); + } + return result; +} + +/* static */ std::unique_ptr> +ReferenceUtil::ReduceToRowArray2D( + const Array2D& matrix, float init, + std::function reduce_function) { + int64 rows = matrix.height(); + int64 cols = matrix.width(); + auto result = MakeUnique>(); + for (int64 i = 0; i < cols; ++i) { + float acc = init; + for (int64 j = 0; j < rows; ++j) { + acc = reduce_function(acc, matrix(j, i)); + } + result->push_back(acc); + } + return result; +} + +/*static*/ std::vector ReferenceUtil::Reduce4DTo1D( + const Array4D& array, float init, + tensorflow::gtl::ArraySlice dims, + std::function reduce_function) { + std::vector result; + CHECK_EQ(dims.size(), 3); + const std::set dim_set(dims.begin(), dims.end()); + CHECK_EQ(dim_set.size(), 3); + for (int64 a0 = 0; a0 == 0 || (!dim_set.count(0) && a0 < array.n1()); ++a0) { + for (int64 a1 = 0; a1 == 0 || (!dim_set.count(1) && a1 < array.n2()); + ++a1) { + for (int64 a2 = 0; a2 == 0 || (!dim_set.count(2) && a2 < array.n3()); + ++a2) { + for (int64 a3 = 0; a3 == 0 || (!dim_set.count(3) && a3 < array.n4()); + ++a3) { + float accumulator = init; + for (int64 i0 = 0; i0 == 0 || (dim_set.count(0) && i0 < array.n1()); + ++i0) { + for (int64 i1 = 0; i1 == 0 || (dim_set.count(1) && i1 < array.n2()); + ++i1) { + for (int64 i2 = 0; + i2 == 0 || (dim_set.count(2) && i2 < array.n3()); ++i2) { + for (int64 i3 = 0; + i3 == 0 || (dim_set.count(3) && i3 < array.n4()); ++i3) { + accumulator = reduce_function( + accumulator, array(a0 + i0, a1 + i1, a2 + i2, a3 + i3)); + } + } + } + } + result.push_back(accumulator); + } + } + } + } + return result; +} + +/* static */ std::unique_ptr> ReferenceUtil::Reduce3DTo2D( + const Array3D& array, float init, + tensorflow::gtl::ArraySlice dims, + std::function reduce_function) { + CHECK_EQ(dims.size(), 1); + int64 rows = dims[0] == 0 ? array.n2() : array.n1(); + int64 cols = dims[0] == 2 ? 
array.n2() : array.n3(); + auto result = MakeUnique>(rows, cols); + result->Fill(init); + for (int i0 = 0; i0 < array.n1(); ++i0) { + for (int i1 = 0; i1 < array.n2(); ++i1) { + for (int i2 = 0; i2 < array.n3(); ++i2) { + int64 row = dims[0] == 0 ? i1 : i0; + int64 col = dims[0] == 2 ? i1 : i2; + (*result)(row, col) = + reduce_function((*result)(row, col), array(i0, i1, i2)); + } + } + } + return result; +} + +/* static */ std::unique_ptr> ReferenceUtil::MapArray2D( + const Array2D& matrix, + const std::function& map_function) { + int64 rows = matrix.height(); + int64 cols = matrix.width(); + auto result = MakeUnique>(rows, cols); + for (int64 i = 0; i < rows; ++i) { + for (int64 j = 0; j < cols; ++j) { + (*result)(i, j) = map_function(matrix(i, j)); + } + } + return result; +} + +/* static */ std::unique_ptr> ReferenceUtil::MapArray2D( + const Array2D& lhs, const Array2D& rhs, + const std::function& map_function) { + CHECK_EQ(lhs.height(), rhs.height()); + CHECK_EQ(lhs.width(), rhs.width()); + int64 rows = lhs.height(); + int64 cols = rhs.width(); + auto result = MakeUnique>(rows, cols); + for (int64 i = 0; i < rows; ++i) { + for (int64 j = 0; j < cols; ++j) { + (*result)(i, j) = map_function(lhs(i, j), rhs(i, j)); + } + } + return result; +} + +/* static */ std::unique_ptr> ReferenceUtil::MapWithIndexArray2D( + const Array2D& matrix, + const std::function& map_function) { + int64 rows = matrix.height(); + int64 cols = matrix.width(); + auto result = MakeUnique>(rows, cols); + for (int64 i = 0; i < rows; ++i) { + for (int64 j = 0; j < cols; ++j) { + (*result)(i, j) = map_function(matrix(i, j), i, j); + } + } + return result; +} + +/* static */ std::unique_ptr> ReferenceUtil::PadArray2D( + const Array2D& operand, const PaddingConfig& padding, + const float pad) { + int64 in0 = operand.n1(); + int64 high_padding0 = padding.dimensions(0).edge_padding_high(); + int64 low_padding0 = padding.dimensions(0).edge_padding_low(); + int64 interior_padding0 = padding.dimensions(0).interior_padding(); + int64 out0 = + in0 + low_padding0 + high_padding0 + (in0 - 1) * interior_padding0; + int64 in1 = operand.n2(); + int64 high_padding1 = padding.dimensions(1).edge_padding_high(); + int64 low_padding1 = padding.dimensions(1).edge_padding_low(); + int64 interior_padding1 = padding.dimensions(1).interior_padding(); + int64 out1 = + in1 + low_padding1 + high_padding1 + (in1 - 1) * interior_padding1; + auto result = MakeUnique>(out0, out1); + result->Fill(pad); + int64 i0 = 0; + for (int64 o0 = low_padding0; o0 < out0 - high_padding0; + o0 += interior_padding0 + 1) { + int64 i1 = 0; + for (int64 o1 = low_padding1; o1 < out1 - high_padding1; + o1 += interior_padding1 + 1) { + (*result)(o0, o1) = operand(i0, i1); + ++i1; + } + ++i0; + } + return result; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h new file mode 100644 index 0000000000..27421b2ac4 --- /dev/null +++ b/tensorflow/compiler/xla/reference_util.h @@ -0,0 +1,382 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_REFERENCE_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_REFERENCE_UTIL_H_ + +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/array3d.h" +#include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/client/padding.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// Utility class for reference implementations of linear algebra routines. +class ReferenceUtil { + public: + // Returns the result of a transpose operation on the input matrix. + static std::unique_ptr> TransposeArray2D( + const Array2D& operand); + + // Returns the result of a matrix multiply `lhs x rhs`. + static std::unique_ptr> MatmulArray2D( + const Array2D& lhs, const Array2D& rhs); + static std::unique_ptr> MatmulArray2D( + const Array2D& lhs, const Array2D& rhs); + + // Converts the input operand to use f64 values instead of f32 values. + static std::unique_ptr> Array2DF32ToF64( + const Array2D& input); + + // Returns the result of a convolution `lhs rhs`, with the default + // convolution dimension numbers returned from + // ComputationBuilder::CreateDefaultConvDimensionNumbers(). + static std::unique_ptr> ConvArray4D( + const Array4D& lhs, const Array4D& rhs, + std::pair kernel_stride, Padding padding); + + // Returns the result of a convolution `lhs rhs`, with the given + // convolution dimension numbers. + static std::unique_ptr> ConvArray4DGeneralDimensions( + const Array4D& lhs, const Array4D& rhs, + std::pair kernel_stride, Padding padding, + ConvolutionDimensionNumbers dimension_numbers); + + // Returns the result of a convolution `lhs rhs`, with the given + // dilation factors. + static std::unique_ptr> ConvArray4DGeneralDimensionsDilated( + const Array4D& lhs, const Array4D& rhs, + std::pair stride, Padding padding, + std::pair lhs_dilation, + std::pair rhs_dilation, ConvolutionDimensionNumbers dnums); + + // Returns the result of reducing a matrix to a column vector. init is the + // initial value for the reduce operation, and reduce_function is the function + // to apply for each reduction step. + static std::unique_ptr> ReduceToColArray2D( + const Array2D& matrix, float init, + std::function reduce_function); + + // Returns the result of reducing a matrix to a row vector. init is the + // initial value for the reduce operation, and reduce_function is the function + // to apply for each reduction step. + static std::unique_ptr> ReduceToRowArray2D( + const Array2D& matrix, float init, + std::function reduce_function); + + // Performs a R2=>R1 reduction by reducing away the dimension specified in + // 'dimension_to_reduce'. + template + static std::vector ReduceR2ToR1(const Array2D& input, + int dimension_to_reduce, T init, + std::function freduce) { + std::vector result(dimension_to_reduce == 0 ? input.n2() : input.n1(), + init); + for (int i0 = 0; i0 < input.n1(); ++i0) { + for (int i1 = 0; i1 < input.n2(); ++i1) { + int output = dimension_to_reduce == 0 ? 
i1 : i0; + result[output] = freduce(result[output], input(i0, i1)); + } + } + return result; + } + + // Returns the result of reducing the 4D array to a vector, reducing away + // the dimensions specified in dims. + static std::vector Reduce4DTo1D( + const Array4D& array, float init, + tensorflow::gtl::ArraySlice dims, + std::function reduce_function); + + // Returns the result of reducing the 3D array to a 2D array, reducing away + // the dimensions specified in dims. + static std::unique_ptr> Reduce3DTo2D( + const Array3D& array, float init, + tensorflow::gtl::ArraySlice dims, + std::function reduce_function); + + // Applies map_function to each element in the input (2D array) and returns + // the result. + static std::unique_ptr> MapArray2D( + const Array2D& matrix, + const std::function& map_function); + + // Applies map_function to each pair of corresponding elements in the two + // inputs arrays and returns the result. + static std::unique_ptr> MapArray2D( + const Array2D& lhs, const Array2D& rhs, + const std::function& map_function); + + // Number of windows in a given dimension. Calculation taken from + // xla::MakePadding(). + static int64 WindowCount(int64 unpadded_width, int64 window_len, int64 stride, + Padding padding); + + // Performs a 4D window reduction with Add as the function to apply. + static std::unique_ptr> ReduceWindow4DAdd( + const Array4D& operand, float init, + const tensorflow::gtl::ArraySlice& window, + const tensorflow::gtl::ArraySlice& stride, Padding padding); + + // Performs select and scatter with Greater Than or equal as the select, plus + // as the scatter, and Same Padding. + static std::unique_ptr> SelectAndScatter4DGePlus( + const Array4D& operand, const Array4D& source, float init, + const tensorflow::gtl::ArraySlice& window, + const tensorflow::gtl::ArraySlice& stride, bool same_padding); + + // Concatenates the lhs and rhs arrays along the concatenate_dimension. + // E.g. if concatenate_dimension is 0, the "n1"/height dimension is + // concatenated, so the arrays are stacked on top of each other. + template + static std::unique_ptr> Concat2D(const Array2D& lhs, + const Array2D& rhs, + int concatenate_dimension) { + CHECK(0 <= concatenate_dimension && concatenate_dimension < 2); + auto result = MakeUnique>( + concatenate_dimension == 0 ? lhs.n1() + rhs.n1() : lhs.n1(), + concatenate_dimension == 1 ? lhs.n2() + rhs.n2() : lhs.n2()); + for (int64 i0 = 0; i0 < result->n1(); ++i0) { + for (int64 i1 = 0; i1 < result->n2(); ++i1) { + // If we exceed the bounds of the LHS, draw from the RHS, where the + // result index is adjusted by the number of values present in the LHS. + (*result)(i0, i1) = i0 < lhs.n1() && i1 < lhs.n2() + ? lhs(i0, i1) + : rhs(i0 >= lhs.n1() ? i0 - lhs.n1() : i0, + i1 >= lhs.n2() ? i1 - lhs.n2() : i1); + } + } + return result; + } + + // Concatenates the lhs and rhs 3D arrays along the concatenate_dimension. lhs + // and rhs must have the same dimensions except for the concatenate dimension. 
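+  // For example (illustrative): concatenating a 2x3x4 array with a 5x3x4 array
+  // along dimension 0 yields a 7x3x4 result; the extents of the other
+  // dimensions must match, which the CHECK_EQ below enforces.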
+ template + static std::unique_ptr> Concat3D(const Array3D& lhs, + const Array3D& rhs, + int concatenate_dimension) { + CHECK(0 <= concatenate_dimension && concatenate_dimension < 3); + std::vector lhs_dims = {lhs.n1(), lhs.n2(), lhs.n3()}; + std::vector rhs_dims = {rhs.n1(), rhs.n2(), rhs.n3()}; + std::vector out_dims = {rhs.n1(), rhs.n2(), rhs.n3()}; + for (int i = 0; i < 3; ++i) { + if (i != concatenate_dimension) { + out_dims[i] = lhs_dims[i]; + CHECK_EQ(lhs_dims[i], rhs_dims[i]); + } else { + out_dims[i] = lhs_dims[i] + rhs_dims[i]; + } + } + auto result = MakeUnique>(out_dims[0], out_dims[1], out_dims[2]); + for (int64 i0 = 0; i0 < result->n1(); ++i0) { + for (int64 i1 = 0; i1 < result->n2(); ++i1) { + for (int64 i2 = 0; i2 < result->n3(); ++i2) { + (*result)(i0, i1, i2) = + i0 < lhs.n1() && i1 < lhs.n2() && i2 < lhs.n3() + ? lhs(i0, i1, i2) + : rhs(i0 >= lhs.n1() ? i0 - lhs.n1() : i0, + i1 >= lhs.n2() ? i1 - lhs.n2() : i1, + i2 >= lhs.n3() ? i2 - lhs.n3() : i2); + } + } + } + return result; + } + + // Concatenates the lhs and rhs 4D arrays along the concatenate_dimension. lhs + // and rhs must have the same dimensions except for the concatenate dimension. + template + static std::unique_ptr> Concat4D(const Array4D& lhs, + const Array4D& rhs, + int concatenate_dimension) { + CHECK(0 <= concatenate_dimension && concatenate_dimension < 4); + std::vector lhs_dims = {lhs.n1(), lhs.n2(), lhs.n3(), lhs.n4()}; + std::vector rhs_dims = {rhs.n1(), rhs.n2(), rhs.n3(), rhs.n4()}; + std::vector out_dims = {rhs.n1(), rhs.n2(), rhs.n3(), rhs.n4()}; + for (int i = 0; i < 4; ++i) { + if (i != concatenate_dimension) { + out_dims[i] = lhs_dims[i]; + CHECK_EQ(lhs_dims[i], rhs_dims[i]); + } else { + out_dims[i] = lhs_dims[i] + rhs_dims[i]; + } + } + auto result = MakeUnique>(out_dims[0], out_dims[1], out_dims[2], + out_dims[3]); + for (int64 i0 = 0; i0 < result->n1(); ++i0) { + for (int64 i1 = 0; i1 < result->n2(); ++i1) { + for (int64 i2 = 0; i2 < result->n3(); ++i2) { + for (int64 i3 = 0; i3 < result->n4(); ++i3) { + (*result)(i0, i1, i2, i3) = + i0 < lhs.n1() && i1 < lhs.n2() && i2 < lhs.n3() && i3 < lhs.n4() + ? lhs(i0, i1, i2, i3) + : rhs(i0 >= lhs.n1() ? i0 - lhs.n1() : i0, + i1 >= lhs.n2() ? i1 - lhs.n2() : i1, + i2 >= lhs.n3() ? i2 - lhs.n3() : i2, + i3 >= lhs.n4() ? i3 - lhs.n4() : i3); + } + } + } + } + return result; + } + + // Slices with modulo-wrapping. + template + static std::vector ModSlice1D(const tensorflow::gtl::ArraySlice& input, + int64 start, int64 size) { + std::vector result; + for (int64 i = 0; i < size; ++i) { + result.push_back(input[(start + i) % input.size()]); + } + return result; + } + + // Slices the input array given starting indices in each dimension and limit + // indices in each dimension. 
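+  // Limits are exclusive, so the slice extent in each dimension is
+  // limits[i] - starts[i]. For example (illustrative), Slice2D with starts
+  // {0, 1} and limits {2, 3} on a 2x3 array returns the 2x2 block made up of
+  // columns 1 and 2.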
+ template + static std::unique_ptr> Slice2D(const Array2D& input, + std::array starts, + std::array limits) { + CHECK_LE(starts[0], input.n1()); + CHECK_LE(starts[1], input.n2()); + CHECK_LE(limits[0], input.n1()); + CHECK_LE(limits[1], input.n2()); + auto result = + MakeUnique>(limits[0] - starts[0], limits[1] - starts[1]); + for (int64 i0 = 0; i0 < result->n1(); ++i0) { + for (int64 i1 = 0; i1 < result->n2(); ++i1) { + (*result)(i0, i1) = input(starts[0] + i0, starts[1] + i1); + } + } + return result; + } + + template + static std::unique_ptr> Slice4D(const Array4D& input, + std::array starts, + std::array limits) { + CHECK_LE(starts[0], input.n1()); + CHECK_LE(starts[1], input.n2()); + CHECK_LE(starts[2], input.n3()); + CHECK_LE(starts[3], input.n4()); + CHECK_LE(limits[0], input.n1()); + CHECK_LE(limits[1], input.n2()); + CHECK_LE(limits[2], input.n3()); + CHECK_LE(limits[3], input.n4()); + auto result = + MakeUnique>(limits[0] - starts[0], limits[1] - starts[1], + limits[2] - starts[2], limits[3] - starts[3]); + for (int64 i0 = 0; i0 < result->n1(); ++i0) { + for (int64 i1 = 0; i1 < result->n2(); ++i1) { + for (int64 i2 = 0; i2 < result->n3(); ++i2) { + for (int64 i3 = 0; i3 < result->n4(); ++i3) { + (*result)(i0, i1, i2, i3) = input(starts[0] + i0, starts[1] + i1, + starts[2] + i2, starts[3] + i3); + } + } + } + } + return result; + } + + template + static std::unique_ptr> Slice3D(const Array3D& input, + std::array starts, + std::array limits) { + CHECK_LE(starts[0], input.n1()); + CHECK_LE(starts[1], input.n2()); + CHECK_LE(starts[2], input.n3()); + CHECK_LE(limits[0], input.n1()); + CHECK_LE(limits[1], input.n2()); + CHECK_LE(limits[2], input.n3()); + auto result = MakeUnique>( + limits[0] - starts[0], limits[1] - starts[1], limits[2] - starts[2]); + for (int64 i0 = 0; i0 < result->n1(); ++i0) { + for (int64 i1 = 0; i1 < result->n2(); ++i1) { + for (int64 i2 = 0; i2 < result->n3(); ++i2) { + (*result)(i0, i1, i2) = + input(starts[0] + i0, starts[1] + i1, starts[2] + i2); + } + } + } + return result; + } + + // Applies map_function to each element in the input (2D array) and returns + // the result. + // (row, column) index of each element is also provided as arguments to + // map_function. + static std::unique_ptr> MapWithIndexArray2D( + const Array2D& matrix, + const std::function& map_function); + + // Applies map_function to each element in the input (4D array) and returns + // the result. + template + static std::unique_ptr> MapArray4D(const Array4D& input, + F&& map_function) { + return MapWithIndexArray4D(input, + [&](float value, int64, int64, int64, int64) { + return map_function(value); + }); + } + + // Applies map_function to each element in the input (4D array) and returns + // the result. + // (plane, depth, height, width) index of each element is also provided as + // arguments to map_function. + template + static std::unique_ptr> MapWithIndexArray4D( + const Array4D& input, F&& map_function) { + auto result = MakeUnique>(input.planes(), input.depth(), + input.height(), input.width()); + for (int64 plane = 0; plane < input.planes(); ++plane) { + for (int64 depth = 0; depth < input.depth(); ++depth) { + for (int64 height = 0; height < input.height(); ++height) { + for (int64 width = 0; width < input.width(); ++width) { + (*result)(plane, depth, height, width) = + map_function(input(plane, depth, height, width), plane, depth, + height, width); + } + } + } + } + return result; + } + + // Returns the result of a 2D pad on an input matrix. 
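+  // Each padded extent follows the implementation in reference_util.cc:
+  // out = in + edge_padding_low + edge_padding_high + (in - 1) * interior_padding.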
+ static std::unique_ptr> PadArray2D( + const Array2D& operand, const PaddingConfig& padding, + const float pad); + + private: + TF_DISALLOW_COPY_AND_ASSIGN(ReferenceUtil); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_REFERENCE_UTIL_H_ diff --git a/tensorflow/compiler/xla/reference_util_test.cc b/tensorflow/compiler/xla/reference_util_test.cc new file mode 100644 index 0000000000..c53351ca93 --- /dev/null +++ b/tensorflow/compiler/xla/reference_util_test.cc @@ -0,0 +1,306 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/reference_util.h" + +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/client/padding.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +// Tests linear algebra routines implemented in ReferenceUtil class. +// TODO(b/23829238): Currently missing tests for the convolution routine. 
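+// All comparisons below go through LiteralTestUtil with ErrorSpec(0.0001), so
+// results only need to match the expected values to within 1e-4.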
+class ReferenceUtilTest : public ::testing::Test { + protected: + ReferenceUtilTest() { + matrix_ = MakeUnique>(rows_, cols_); + // [1.f 2.f 3.f] + // [4.f 5.f 6.f] + for (int64 i = 0; i < rows_; ++i) { + for (int64 j = 0; j < cols_; ++j) { + (*matrix_)(i, j) = i * cols_ + j + 1; + } + } + } + + const int64 rows_ = 2; + const int64 cols_ = 3; + std::unique_ptr> matrix_; +}; + +TEST_F(ReferenceUtilTest, TransposeArray2D) { + auto result = ReferenceUtil::TransposeArray2D(*matrix_); + auto result_literal = LiteralUtil::CreateR2FromArray2D(*result); + LiteralTestUtil::ExpectR2Near({{1.f, 4.f}, {2.f, 5.f}, {3.f, 6.f}}, + *result_literal, ErrorSpec(0.0001)); +} + +TEST_F(ReferenceUtilTest, MatmulArray2D) { + Array2D rhs({ + {7.f, 8.f}, {9.f, 10.f}, {11.f, 12.f}, + }); + auto result = ReferenceUtil::MatmulArray2D(*matrix_, rhs); + auto result_literal = LiteralUtil::CreateR2FromArray2D(*result); + LiteralTestUtil::ExpectR2Near({{58.f, 64.f}, {139.f, 154.f}}, + *result_literal, ErrorSpec(0.0001)); +} + +TEST_F(ReferenceUtilTest, ReduceToColArray2D) { + auto add = [](float lhs, float rhs) { return lhs + rhs; }; + auto result = ReferenceUtil::ReduceToColArray2D(*matrix_, 0.0f, add); + auto result_literal = LiteralUtil::CreateR1(*result); + LiteralTestUtil::ExpectR1Near({6.f, 15.f}, *result_literal, + ErrorSpec(0.0001)); +} + +TEST_F(ReferenceUtilTest, ReduceToRowArray2D) { + auto add = [](float lhs, float rhs) { return lhs + rhs; }; + auto result = ReferenceUtil::ReduceToRowArray2D(*matrix_, 0.0f, add); + auto result_literal = LiteralUtil::CreateR1(*result); + LiteralTestUtil::ExpectR1Near({5.f, 7.f, 9.f}, *result_literal, + ErrorSpec(0.0001)); +} + +TEST_F(ReferenceUtilTest, MapArray2D) { + auto identity = [](float value) { return log(exp(value)); }; + auto result = ReferenceUtil::MapArray2D(*matrix_, identity); + auto result_literal = LiteralUtil::CreateR2FromArray2D(*result); + LiteralTestUtil::ExpectR2NearArray2D(*matrix_, *result_literal, + ErrorSpec(0.0001)); +} + +TEST_F(ReferenceUtilTest, MapWithIndexArray2D) { + auto add_index = [](float value, int64 row, int64 col) { + return value + row + col; + }; + auto result = ReferenceUtil::MapWithIndexArray2D(*matrix_, add_index); + auto result_literal = LiteralUtil::CreateR2FromArray2D(*result); + LiteralTestUtil::ExpectR2Near({{1.f, 3.f, 5.f}, {5.f, 7.f, 9.f}}, + *result_literal, ErrorSpec(0.0001)); +} + +TEST_F(ReferenceUtilTest, MapArray4D) { + auto input = MakeUnique>(/*planes=*/2, /*depth=*/3, + /*height=*/4, /*width=*/5); + input->FillWithMultiples(1.0f); + auto multiply_by_two = [](float value) { return 2 * value; }; + auto result = ReferenceUtil::MapArray4D(*input, multiply_by_two); + auto result_literal = LiteralUtil::CreateR4FromArray4D(*result); + + Array4D expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5); + expected.FillWithMultiples(2.0f); + LiteralTestUtil::ExpectR4NearArray4D(expected, *result_literal, + ErrorSpec(0.0001)); +} + +TEST_F(ReferenceUtilTest, MapWithIndexArray4D) { + auto input = MakeUnique>(/*planes=*/2, /*depth=*/3, + /*height=*/4, /*width=*/5); + input->FillWithMultiples(1.0f); + auto subtract_index = [](float value, int64 plane, int64 depth, int64 height, + int64 width) { + return value - (3 * 4 * 5 * plane + 4 * 5 * depth + 5 * height + width); + }; + auto result = ReferenceUtil::MapWithIndexArray4D(*input, subtract_index); + auto result_literal = LiteralUtil::CreateR4FromArray4D(*result); + + Array4D expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5); + expected.Fill(0.0f); + 
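+  // With a multiplier of 1, each element holds its own linear
+  // (plane, depth, height, width) index, so subtract_index cancels it exactly
+  // and the result is expected to be identically zero.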
LiteralTestUtil::ExpectR4NearArray4D(expected, *result_literal, + ErrorSpec(0.0001)); +} + +TEST_F(ReferenceUtilTest, ConvWithSamePadding) { + Array4D input(1, 1, 4, 4); + // clang-format off + input.FillWithYX(Array2D({ + {1, 2, 3, 4 }, + {5, 6, 7, 8 }, + {9, 10, 11, 12}, + {13, 14, 15, 16}, + })); + // clang-format on + Array4D weights(1, 1, 2, 2); + // clang-format off + weights.FillWithYX(Array2D({ + {5, 6}, + {7, 8}, + })); + // clang-format on + std::unique_ptr> actual = + ReferenceUtil::ConvArray4D(input, weights, {1, 1}, Padding::kSame); + Array4D expected(1, 1, 4, 4); + // clang-format off + expected.FillWithYX(Array2D({ + {100, 126, 152, 76}, + {204, 230, 256, 124}, + {308, 334, 360, 172}, + {149, 160, 171, 80}, + })); + // clang-format on + + auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual); + + LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, + ErrorSpec(0.0001)); +} + +TEST_F(ReferenceUtilTest, ConvWithValidPadding) { + Array4D input(1, 1, 4, 4); + // clang-format off + input.FillWithYX(Array2D({ + {1, 2, 3, 4 }, + {5, 6, 7, 8 }, + {9, 10, 11, 12}, + {13, 14, 15, 16}, + })); + // clang-format on + Array4D weights(1, 1, 2, 2); + // clang-format off + weights.FillWithYX(Array2D({ + {5, 6}, + {7, 8}, + })); + // clang-format on + std::unique_ptr> actual = + ReferenceUtil::ConvArray4D(input, weights, {1, 1}, Padding::kValid); + Array4D expected(1, 1, 3, 3); + // clang-format off + expected.FillWithYX(Array2D({ + {1*5+2*6+5*7+6*8, 126, 152}, + {204, 230, 256}, + {308, 334, 11*5+12*6+15*7+16*8}, + })); + // clang-format on + + auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual); + + LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, + ErrorSpec(0.0001)); +} + +TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithSamePadding) { + // clang-format off + // Input dimensions: [feature=2, height=3, batch=1, width=4] + Array4D input({ + {{{1, 2, 3, 4}}, + {{5, 6, 7, 8}}, + {{9, 10, 11, 12}}}, + {{{13, 14, 15, 16}}, + {{17, 18, 19, 20}}, + {{21, 22, 23, 24}}} + }); + // Weight dimensions: + // [kernel_output_feature=1, height=3, kernel_input_feature=2, width=3] + Array4D weight({{ + {{1, 2, 3}, + {4, 5, 6}}, + {{7, 8, 9}, + {10, 11, 12}}, + {{13, 14, 15}, + {16, 17, 18}} + }}); + // clang-format on + + // Set the convolution dimension numbers. 
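+  // The input above is laid out as [feature, height, batch, width] and the
+  // weights as [kernel_output_feature, height, kernel_input_feature, width],
+  // which is what the batch/feature/spatial assignments below encode.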
+ ConvolutionDimensionNumbers dimension_numbers; + dimension_numbers.set_batch_dimension(2); + dimension_numbers.set_feature_dimension(0); + dimension_numbers.add_spatial_dimensions(1); + dimension_numbers.add_spatial_dimensions(3); + dimension_numbers.set_kernel_output_feature_dimension(0); + dimension_numbers.set_kernel_input_feature_dimension(2); + dimension_numbers.add_kernel_spatial_dimensions(1); + dimension_numbers.add_kernel_spatial_dimensions(3); + + std::unique_ptr> actual = + ReferenceUtil::ConvArray4DGeneralDimensions( + input, weight, {1, 1}, Padding::kSame, dimension_numbers); + // clang-format off + // Result dimensions: [feature=1, height=3, batch=1, width=4] + Array4D expected({{ + {{1110, 1688, 1838, 1226}}, + {{1683, 2514, 2685, 1761}}, + {{878, 1280, 1358, 866}} + }}); + // clang-format on + + auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual); + + LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, + ErrorSpec(0.0001)); +} + +TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithValidPadding) { + // clang-format off + // Input dimensions: [feature=2, height=3, batch=1, width=4] + Array4D input({ + {{{1, 2, 3, 4}}, + {{5, 6, 7, 8}}, + {{9, 10, 11, 12}}}, + {{{13, 14, 15, 16}}, + {{17, 18, 19, 20}}, + {{21, 22, 23, 24}}} + }); + // Weight dimensions: + // [kernel_output_feature=1, width=3, kernel_input_feature=2, height=3] + Array4D weight({{ + {{1, 7, 13}, + {4, 10, 16}}, + {{2, 8, 14}, + {5, 11, 17}}, + {{3, 9, 15}, + {6, 12, 18}} + }}); + // clang-format on + + // Set the convolution dimension numbers. + ConvolutionDimensionNumbers dimension_numbers; + dimension_numbers.set_batch_dimension(2); + dimension_numbers.set_feature_dimension(0); + dimension_numbers.add_spatial_dimensions(1); + dimension_numbers.add_spatial_dimensions(3); + + dimension_numbers.set_kernel_output_feature_dimension(0); + dimension_numbers.set_kernel_input_feature_dimension(2); + dimension_numbers.add_kernel_spatial_dimensions(3); + dimension_numbers.add_kernel_spatial_dimensions(1); + + std::unique_ptr> actual = + ReferenceUtil::ConvArray4DGeneralDimensions( + input, weight, {1, 1}, Padding::kValid, dimension_numbers); + // clang-format off + // Result dimensions: [feature=1, height=1, batch=1, width=2] + Array4D expected({{{{2514, 2685}}}}); + // clang-format on + + auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual); + + LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, + ErrorSpec(0.0001)); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD new file mode 100644 index 0000000000..fb36443801 --- /dev/null +++ b/tensorflow/compiler/xla/service/BUILD @@ -0,0 +1,1216 @@ +# Description: +# XLA service implementation. + +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = [":friends"]) + +package_group( + name = "friends", + includes = [ + "//tensorflow/compiler/xla:friends", + ], +) + +load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") + +xla_proto_library( + name = "session_proto", + srcs = ["session.proto"], + visibility = ["//visibility:public"], + deps = ["//tensorflow/compiler/xla:xla_data_proto"], +) + +# Filegroup used to collect source files for dependency checking. 
+filegroup( + name = "c_srcs", + data = glob([ + "**/*.cc", + "**/*.h", + ]), +) + +cc_library( + name = "shape_inference", + srcs = ["shape_inference.cc"], + hdrs = ["shape_inference.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:window_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "shape_inference_test", + srcs = ["shape_inference_test.cc"], + deps = [ + ":shape_inference", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:test_main", + ], +) + +cc_test( + name = "hlo_opcode_test", + srcs = ["hlo_opcode_test.cc"], + deps = [ + ":hlo", + "//tensorflow/compiler/xla:types", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "hlo", + srcs = [ + "dfs_hlo_visitor.cc", + "hlo_computation.cc", + "hlo_instruction.cc", + "hlo_module.cc", + "hlo_opcode.cc", + ], + hdrs = [ + "dfs_hlo_visitor.h", + "dfs_hlo_visitor_with_default.h", + "hlo_computation.h", + "hlo_instruction.h", + "hlo_module.h", + "hlo_opcode.h", + ], + deps = [ + ":name_uniquer", + ":versioned_computation_handle", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:protobuf_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:window_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + +cc_library( + name = "versioned_computation_handle", + hdrs = ["versioned_computation_handle.h"], + deps = [ + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "hlo_instruction_test", + srcs = ["hlo_instruction_test.cc"], + deps = [ + ":hlo", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "user_computation", + srcs = ["user_computation.cc"], + hdrs = ["user_computation.h"], + deps = [ + ":hlo", + ":session_proto", + ":shape_inference", + ":versioned_computation_handle", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "platform_util", + srcs = ["platform_util.cc"], + hdrs = ["platform_util.h"], + deps = [ + ":compiler", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + ], +) + +cc_library( + name = "backend", + srcs = ["backend.cc"], + hdrs = ["backend.h"], + deps = [ + ":compiler", + ":device_memory_allocator", + ":platform_util", + ":transfer_manager", + "//tensorflow/compiler/xla:status_macros", + 
"//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/legacy_flags:backend_flags", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + "//third_party/eigen3", + ], +) + +cc_library( + name = "service", + srcs = ["service.cc"], + hdrs = ["service.h"], + deps = [ + ":allocation_tracker", + ":backend", + ":channel_tracker", + ":compilation_cache", + ":compiler", + ":computation_layout", + ":computation_tracker", + ":cpu_transfer_manager", + ":device_memory_allocator", + ":executable", + ":execution_tracker", + ":hlo", + ":hlo_cost_analysis", + ":hlo_execution_profile", + ":hlo_graph_dumper", + ":hlo_module_config", + ":platform_util", + ":session_proto", + ":transfer_manager", + ":user_computation", + ":versioned_computation_handle", + "//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/compiler/xla:service_interface", + "//tensorflow/compiler/xla:shape_layout", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla:xla_proto", + "//tensorflow/compiler/xla/legacy_flags:service_flags", + "//tensorflow/compiler/xla/service/cpu:cpu_compiler", + "//tensorflow/core:lib", + "//tensorflow/core:regexp_internal", + "//tensorflow/core:stream_executor_no_cuda", + ], + alwayslink = 1, +) + +cc_library( + name = "local_service", + srcs = ["local_service.cc"], + hdrs = ["local_service.h"], + deps = [ + ":backend", + ":compiler", + ":computation_layout", + ":computation_tracker", + ":device_memory_allocator", + ":executable", + ":hlo", + ":hlo_execution_profile", + ":hlo_module_config", + ":platform_util", + ":service", + ":shaped_buffer", + ":user_computation", + ":versioned_computation_handle", + "//tensorflow/compiler/xla:shape_layout", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + ], +) + +cc_library( + name = "cpu_plugin", + deps = [ + ":cpu_transfer_manager", + ":service", + "//tensorflow/compiler/xla/service/cpu:cpu_compiler", + "//tensorflow/core:stream_executor_no_cuda", + ], +) + +cc_library( + name = "gpu_plugin", + deps = [ + ":generic_transfer_manager", + ":service", + "//tensorflow/compiler/xla/service/gpu:gpu_compiler", + "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform/default/build_config:stream_executor_cuda", + ], +) + +cc_library( + name = "shaped_buffer", + srcs = ["shaped_buffer.cc"], + hdrs = ["shaped_buffer.h"], + deps = [ + ":device_memory_allocator", + "//tensorflow/compiler/xla:shape_tree", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + ], +) + +cc_library( + name = "executable", + srcs = ["executable.cc"], + hdrs = ["executable.h"], + deps = [ + ":computation_layout", + ":device_memory_allocator", + ":hlo", + ":hlo_execution_profile", 
+ ":hlo_module_config", + ":session_proto", + ":shaped_buffer", + ":versioned_computation_handle", + "//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/legacy_flags:service_flags", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + ], +) + +cc_library( + name = "compiler", + srcs = ["compiler.cc"], + hdrs = ["compiler.h"], + deps = [ + ":executable", + ":hlo", + ":hlo_module_config", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + ], +) + +cc_library( + name = "transfer_manager", + srcs = ["transfer_manager.cc"], + hdrs = ["transfer_manager.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + ], +) + +cc_library( + name = "allocation_tracker", + srcs = ["allocation_tracker.cc"], + hdrs = ["allocation_tracker.h"], + deps = [ + ":backend", + ":device_memory_allocator", + ":transfer_manager", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + ], +) + +cc_library( + name = "execution_tracker", + srcs = ["execution_tracker.cc"], + hdrs = ["execution_tracker.h"], + deps = [ + ":backend", + "//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + ], +) + +cc_library( + name = "computation_tracker", + srcs = ["computation_tracker.cc"], + hdrs = ["computation_tracker.h"], + deps = [ + ":hlo", + ":session_proto", + ":user_computation", + ":versioned_computation_handle", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "channel_tracker", + srcs = ["channel_tracker.cc"], + hdrs = ["channel_tracker.h"], + deps = [ + ":hlo", + ":session_proto", + ":user_computation", + ":versioned_computation_handle", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + +cc_library( + name = "name_uniquer", + srcs = ["name_uniquer.cc"], + hdrs = ["name_uniquer.h"], + deps = [ + "//tensorflow/compiler/xla:types", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "buffer_liveness", + srcs = [ + "buffer_liveness.cc", + ], + hdrs = [ + "buffer_liveness.h", + ], + deps = [ + ":hlo", + ":logical_buffer", + ":tuple_points_to_analysis", + "//tensorflow/compiler/xla:shape_util", + 
"//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "buffer_liveness_test", + srcs = ["buffer_liveness_test.cc"], + deps = [ + ":buffer_liveness", + ":cpu_plugin", + ":hlo", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "buffer_assignment", + srcs = [ + "buffer_assignment.cc", + ], + hdrs = [ + "buffer_assignment.h", + ], + deps = [ + ":buffer_liveness", + ":hlo", + ":logical_buffer", + ":tuple_points_to_analysis", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/legacy_flags:buffer_assignment_flags", + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "buffer_assignment_test", + srcs = ["buffer_assignment_test.cc"], + deps = [ + ":buffer_assignment", + ":computation_tracker", + ":cpu_plugin", + ":hlo", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:lib", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "hlo_query", + srcs = ["hlo_query.cc"], + hdrs = ["hlo_query.h"], + deps = [ + ":hlo", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + ], +) + +cc_library( + name = "instruction_fusion", + srcs = ["instruction_fusion.cc"], + hdrs = ["instruction_fusion.h"], + deps = [ + ":hlo", + ":hlo_pass", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "instruction_fusion_test", + srcs = ["instruction_fusion_test.cc"], + deps = [ + ":instruction_fusion", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "algebraic_simplifier", + srcs = ["algebraic_simplifier.cc"], + hdrs = ["algebraic_simplifier.h"], + deps = [ + ":hlo", + ":hlo_pass", + ":hlo_query", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:window_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "algebraic_simplifier_test", + srcs = ["algebraic_simplifier_test.cc"], + deps = [ + ":algebraic_simplifier", + ":cpu_plugin", + ":hlo", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:lib", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "reshape_mover", + srcs = ["reshape_mover.cc"], + hdrs = ["reshape_mover.h"], + deps = [ + ":hlo_pass", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", + ], +) + 
+cc_test( + name = "reshape_mover_test", + srcs = ["reshape_mover_test.cc"], + deps = [ + ":hlo", + ":reshape_mover", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:lib", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "inliner", + srcs = ["inliner.cc"], + hdrs = ["inliner.h"], + deps = [ + ":hlo", + ":hlo_pass", + ":hlo_query", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:types", + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "inliner_test", + srcs = ["inliner_test.cc"], + deps = [ + ":cpu_plugin", + ":hlo", + ":inliner", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "generic_transfer_manager", + srcs = ["generic_transfer_manager.cc"], + hdrs = ["generic_transfer_manager.h"], + deps = [ + ":transfer_manager", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + ], + alwayslink = True, # Contains per-platform transfer manager registration +) + +cc_library( + name = "cpu_transfer_manager", + srcs = ["cpu_transfer_manager.cc"], + hdrs = ["cpu_transfer_manager.h"], + deps = [ + ":generic_transfer_manager", + ":transfer_manager", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service/cpu:cpu_runtime", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + ], + alwayslink = True, # Contains per-platform transfer manager registration +) + +cc_test( + name = "transfer_manager_test", + srcs = ["transfer_manager_test.cc"], + deps = [ + ":cpu_transfer_manager", + ":generic_transfer_manager", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "hlo_cost_analysis", + srcs = [ + "hlo_cost_analysis.cc", + ], + hdrs = [ + "hlo_cost_analysis.h", + ], + deps = [ + ":hlo", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "hlo_cost_analysis_test", + srcs = ["hlo_cost_analysis_test.cc"], + deps = [ + 
":computation_tracker", + ":hlo", + ":hlo_cost_analysis", + ":local_service", + ":service", + ":user_computation", + ":versioned_computation_handle", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla/client", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:padding", + "//tensorflow/core:lib", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "hlo_execution_profile", + srcs = ["hlo_execution_profile.cc"], + hdrs = ["hlo_execution_profile.h"], + deps = [ + ":hlo", + ":hlo_cost_analysis", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + ], +) + +cc_test( + name = "hlo_computation_test", + srcs = ["hlo_computation_test.cc"], + deps = [ + ":cpu_plugin", + ":hlo", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test_main", + ], +) + +cc_binary( + name = "graphviz_example", + srcs = ["graphviz_example.cc"], + deps = [ + ":hlo", + ":hlo_graph_dumper", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "hlo_module_test", + srcs = ["hlo_module_test.cc"], + deps = [ + ":cpu_plugin", + ":hlo", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:lib", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "logical_buffer", + srcs = [ + "logical_buffer.cc", + ], + hdrs = [ + "logical_buffer.h", + ], + deps = [ + ":hlo", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "tuple_points_to_analysis", + srcs = [ + "tuple_points_to_analysis.cc", + ], + hdrs = [ + "tuple_points_to_analysis.h", + ], + deps = [ + ":hlo", + ":logical_buffer", + "//tensorflow/compiler/xla:shape_tree", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "tuple_points_to_analysis_test", + srcs = ["tuple_points_to_analysis_test.cc"], + deps = [ + ":hlo", + ":tuple_points_to_analysis", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "compilation_cache", + srcs = [ + "compilation_cache.cc", + ], + hdrs = [ + "compilation_cache.h", + ], + deps = [ + ":executable", + ":hlo_module_config", + 
":versioned_computation_handle", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "layout_assignment", + srcs = [ + "layout_assignment.cc", + ], + hdrs = [ + "layout_assignment.h", + ], + deps = [ + ":computation_layout", + ":hlo", + ":hlo_graph_dumper", + ":hlo_pass", + ":logical_buffer", + ":tuple_points_to_analysis", + "//tensorflow/compiler/xla:shape_layout", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "copy_insertion", + srcs = ["copy_insertion.cc"], + hdrs = ["copy_insertion.h"], + deps = [ + ":buffer_liveness", + ":hlo", + ":hlo_pass", + ":logical_buffer", + ":tuple_points_to_analysis", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "copy_insertion_test", + srcs = ["copy_insertion_test.cc"], + deps = [ + ":buffer_liveness", + ":copy_insertion", + ":cpu_plugin", + ":hlo", + ":tuple_points_to_analysis", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "hlo_dce", + srcs = ["hlo_dce.cc"], + hdrs = ["hlo_dce.h"], + deps = [ + ":hlo", + ":hlo_pass", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "hlo_dce_test", + srcs = ["hlo_dce_test.cc"], + deps = [ + ":cpu_plugin", + ":hlo", + ":hlo_dce", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/core:lib", + "//tensorflow/core:test_main", + ], +) + +cc_test( + name = "layout_assignment_test", + srcs = ["layout_assignment_test.cc"], + deps = [ + ":algebraic_simplifier", + ":computation_layout", + ":cpu_plugin", + ":hlo", + ":layout_assignment", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_layout", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/core:lib", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "hlo_pass", + hdrs = ["hlo_pass.h"], + deps = [ + ":hlo", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "hlo_pass_pipeline", + srcs = [ + "hlo_pass_pipeline.cc", + ], + hdrs = [ + "hlo_pass_pipeline.h", + ], + deps = [ + ":compiler", + ":hlo", + ":hlo_pass", + 
"//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/legacy_flags:hlo_pass_pipeline_flags", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "hlo_cse", + srcs = ["hlo_cse.cc"], + hdrs = ["hlo_cse.h"], + deps = [ + ":hlo", + ":hlo_pass", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + ], +) + +cc_test( + name = "hlo_cse_test", + srcs = ["hlo_cse_test.cc"], + deps = [ + ":cpu_plugin", + ":hlo", + ":hlo_cse", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/core:lib", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "device_memory_allocator", + srcs = ["device_memory_allocator.cc"], + hdrs = ["device_memory_allocator.h"], + deps = [ + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + ], +) + +cc_library( + name = "elemental_ir_emitter", + srcs = ["elemental_ir_emitter.cc"], + hdrs = ["elemental_ir_emitter.h"], + deps = [ + ":hlo", + ":hlo_module_config", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service/llvm_ir:ir_array", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "@llvm//:core", + "@llvm//:transform_utils", + ], +) + +cc_library( + name = "hlo_module_config", + srcs = ["hlo_module_config.cc"], + hdrs = ["hlo_module_config.h"], + deps = [ + ":computation_layout", + "//tensorflow/compiler/xla:shape_layout", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "computation_layout", + srcs = ["computation_layout.cc"], + hdrs = ["computation_layout.h"], + deps = [ + "//tensorflow/compiler/xla:shape_layout", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "hlo_subcomputation_unification", + srcs = ["hlo_subcomputation_unification.cc"], + hdrs = ["hlo_subcomputation_unification.h"], + deps = [ + ":hlo_pass", + ], +) + +cc_test( + name = "hlo_subcomputation_unification_test", + srcs = ["hlo_subcomputation_unification_test.cc"], + deps = [ + ":hlo", + ":hlo_graph_dumper", + ":hlo_subcomputation_unification", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "hlo_graph_dumper", + srcs = [ + "hlo_graph_dumper.cc", + ], + hdrs = ["hlo_graph_dumper.h"], + deps = [ + ":hlo", + ":hlo_execution_profile", + 
"//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/legacy_flags:hlo_graph_dumper_flags", + "//tensorflow/core:lib", + ], + alwayslink = 1, +) + +cc_library( + name = "transpose_folding", + srcs = ["transpose_folding.cc"], + hdrs = ["transpose_folding.h"], + deps = [ + ":hlo", + ":hlo_pass", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla/service/gpu:ir_emission_utils", + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "transpose_folding_test", + srcs = ["transpose_folding_test.cc"], + deps = [ + ":hlo", + ":transpose_folding", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service/gpu:ir_emission_utils", + "//tensorflow/core:lib", + "//tensorflow/core:test_main", + ], +) + +# ----------------------------------------------------------------------------- + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc new file mode 100644 index 0000000000..fe892e872f --- /dev/null +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -0,0 +1,938 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/algebraic_simplifier.h" + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_query.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/window_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +// Returns whether operand is a literal with the given value. 
+bool IsLiteralWithValue(const HloInstruction* operand, int value) { + return operand->opcode() == HloOpcode::kConstant && + LiteralUtil::IsAll(operand->literal(), value); +} + +// Returns whether the given transpose produces a result which is bit-wise +// identical to its operand and thus may be replaced with a bitcast. +bool TransposeIsBitcast( + const HloInstruction* transpose, + const AlgebraicSimplifier::ValidBitcastCallback& valid_bitcast_callback) { + CHECK_EQ(HloOpcode::kTranspose, transpose->opcode()); + const HloInstruction* operand = transpose->operand(0); + + // Can't insert bitcasts if the compiler used a memory layout which isn't + // compatible. + if (!valid_bitcast_callback(operand->shape(), transpose->shape())) { + return false; + } + + return ShapeUtil::TransposeIsBitcast(operand->shape(), transpose->shape(), + transpose->dimensions()); +} + +// Returns true if the given reshape produces a result which is bit-wise +// identical to its operand and thus may be replaced with a bitcast. +// +// This function is conservative -- even if this function returns false, the +// reshape may still be a bitcast. For example, a reshape from [28x28] to [784]. +bool ReshapeIsBitcast( + const HloInstruction* reshape, + const AlgebraicSimplifier::ValidBitcastCallback& valid_bitcast_callback) { + CHECK_EQ(HloOpcode::kReshape, reshape->opcode()); + + const HloInstruction* operand = reshape->operand(0); + // Can't insert bitcasts if the compiler used a memory layout which isn't + // compatible. + if (!valid_bitcast_callback(operand->shape(), reshape->shape())) { + return false; + } + + return ShapeUtil::ReshapeIsBitcast(operand->shape(), reshape->shape()); +} +} // namespace + +// AlgebraicSimplifierVisitor traverses the HLO computation and reduces certain +// algebraic expressions to simplified forms. Note: This only supports +// simplifications that simply look at the operands of an instruction. For the +// more general case a worklist based approach would be needed. +class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { + public: + // Default visitor action is to do nothing and return OK. 
+ Status DefaultAction(HloInstruction* /*hlo_instruction*/) override { + return Status::OK(); + } + + Status HandleAdd(HloInstruction* add, HloInstruction* lhs, + HloInstruction* rhs) override; + + Status HandleBroadcast(HloInstruction* broadcast) override; + + Status HandleCopy(HloInstruction* copy, HloInstruction* operand) override; + + Status HandleConvert(HloInstruction* convert, + HloInstruction* operand) override; + + Status HandleConvolution(HloInstruction* convolution, HloInstruction* lhs, + HloInstruction* rhs, const Window& window) override; + + Status HandleDivide(HloInstruction* divide, HloInstruction* lhs, + HloInstruction* rhs) override; + + Status HandleGetTupleElement(HloInstruction* get_tuple_element, + HloInstruction* operand) override; + + Status HandleLog(HloInstruction* log, HloInstruction* operand) override; + + Status HandleMultiply(HloInstruction* multiply, HloInstruction* lhs, + HloInstruction* rhs) override; + + Status HandlePad(HloInstruction* pad) override; + + Status HandlePower(HloInstruction* power, HloInstruction* lhs, + HloInstruction* rhs) override; + + Status HandleReshape(HloInstruction* reshape) override; + + Status HandleReduce(HloInstruction* reduce, HloInstruction* arg, + HloInstruction* init_value, + tensorflow::gtl::ArraySlice dimensions, + HloComputation* function) override; + + Status HandleSlice(HloInstruction* slice, HloInstruction* operand) override; + + Status HandleTranspose(HloInstruction* transpose) override; + + Status HandleSubtract(HloInstruction* sub, HloInstruction* lhs, + HloInstruction* rhs) override; + + Status HandleMaximum(HloInstruction* maximum, HloInstruction* lhs, + HloInstruction* rhs) override; + + Status HandleMinimum(HloInstruction* minimum, HloInstruction* lhs, + HloInstruction* rhs) override; + + // Returns whether algebraic simplification has occurred. + const bool changed() const { return changed_; } + + // Runs the visitor on a computation. + static bool Run( + HloComputation* computation, bool is_layout_sensitive, + AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback); + + private: + explicit AlgebraicSimplifierVisitor( + HloComputation* computation, bool is_layout_sensitive, + AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback) + : computation_(computation), + is_layout_sensitive_(is_layout_sensitive), + valid_bitcast_callback_(std::move(valid_bitcast_callback)) {} + + // Convenience method for replacing an instruction with a bitcast. + void ReplaceWithBitcast(HloInstruction* instruction); + + // Replace old instruction with new instruction if old and new instructions + // have the same shape. Updates uses and root instruction. Returns whether a + // replacement was made. + bool ReplaceInstructionIfSameShape(HloInstruction* old_instruction, + HloInstruction* new_instruction); + + // Returns whether the shape of the output of the given instructions are the + // same for the purposes of simplification. If is_layout_sensitive_ is true, + // then this tests shape equality including layout (ShapeUtil::Equal). If + // is_layout_sensitive_ is false, then the tests shape compatibility + // (ShapeUtil::Compatible). + bool SameShape(const HloInstruction* lhs, const HloInstruction* rhs) const; + + // Returns whether it was possible to transform `root` to a clamp instruction. + // With min a minimum instruction, max a maximum instruction, min_operand a + // operand of min and max_operand a operand of max. + // Precondition: root is either a minimum or a maximum. 
+ bool TransformToClampIfSameShape(HloInstruction* root, HloInstruction* min, + HloInstruction* min_operand, + HloInstruction* operand, HloInstruction* max, + HloInstruction* max_operand); + + // Current HloComputation instance the AlgebraicSimplifierVisitor is + // traversing. + HloComputation* computation_; + + // Whether algebraic simplification has occurred. + bool changed_ = false; + + // Whether layout is considered during transformation. + bool is_layout_sensitive_; + + // Callback used to determine if a bitcast is valid. + AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback_; +}; + +bool AlgebraicSimplifierVisitor::Run( + HloComputation* computation, bool is_layout_sensitive, + AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback) { + AlgebraicSimplifierVisitor visitor(computation, is_layout_sensitive, + std::move(valid_bitcast_callback)); + TF_CHECK_OK(computation->root_instruction()->Accept(&visitor)); + return visitor.changed_; +} + +bool AlgebraicSimplifierVisitor::SameShape(const HloInstruction* lhs, + const HloInstruction* rhs) const { + if (is_layout_sensitive_) { + return ShapeUtil::Equal(lhs->shape(), rhs->shape()); + } else { + return ShapeUtil::Compatible(lhs->shape(), rhs->shape()); + } +} + +void AlgebraicSimplifierVisitor::ReplaceWithBitcast( + HloInstruction* instruction) { + CHECK_EQ(1, instruction->operand_count()); + CHECK_EQ(ShapeUtil::ElementsIn(instruction->shape()), + ShapeUtil::ElementsIn(instruction->operand(0)->shape())); + CHECK_EQ(ShapeUtil::ByteSizeOf(instruction->shape()), + ShapeUtil::ByteSizeOf(instruction->operand(0)->shape())); + + auto bitcast = computation_->AddInstruction( + HloInstruction::CreateUnary(instruction->shape(), HloOpcode::kBitcast, + instruction->mutable_operand(0))); + computation_->ReplaceInstruction(instruction, bitcast); + changed_ = true; +} + +bool AlgebraicSimplifierVisitor::ReplaceInstructionIfSameShape( + HloInstruction* old_instruction, HloInstruction* new_instruction) { + if (!SameShape(old_instruction, new_instruction)) { + return false; + } + computation_->ReplaceInstruction(old_instruction, new_instruction); + changed_ = true; + return true; +} + +Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add, + HloInstruction* lhs, + HloInstruction* rhs) { + // A + 0 => A + VLOG(10) << "trying transform [A + 0 => A]: " << add->ToString(); + if (IsLiteralWithValue(rhs, 0) && ReplaceInstructionIfSameShape(add, lhs)) { + return Status::OK(); + } + // 0 + A => A + VLOG(10) << "trying transform [0 + A => A]: " << add->ToString(); + if (IsLiteralWithValue(lhs, 0) && ReplaceInstructionIfSameShape(add, rhs)) { + return Status::OK(); + } + + return Status::OK(); +} + +Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy, + HloInstruction* operand) { + // All copies can be eliminated (assuming layout constraints are satisified). 
+ ReplaceInstructionIfSameShape(copy, operand); + return Status::OK(); +} + +Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub, + HloInstruction* lhs, + HloInstruction* rhs) { + // A - 0 => A + VLOG(10) << "trying transform [A - 0 => A]: " << sub->ToString(); + if (IsLiteralWithValue(rhs, 0) && ReplaceInstructionIfSameShape(sub, lhs)) { + return Status::OK(); + } + + return Status::OK(); +} + +Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide, + HloInstruction* lhs, + HloInstruction* rhs) { + // A/1 => A + VLOG(10) << "trying transform [A/1 => A]: " << divide->ToString(); + if (IsLiteralWithValue(rhs, 1) && ReplaceInstructionIfSameShape(divide, lhs)) { + return Status::OK(); + } + + // exp(A)/exp(B) => exp(A-B) + if (lhs->opcode() == HloOpcode::kExp && rhs->opcode() == HloOpcode::kExp) { + VLOG(10) << "transform [exp(A)/exp(B) => exp(A-B)]: " << divide->ToString(); + HloInstruction* subtract = + computation_->AddInstruction(HloInstruction::CreateBinary( + divide->shape(), HloOpcode::kSubtract, lhs->mutable_operand(0), + rhs->mutable_operand(0))); + computation_->ReplaceWithNewInstruction( + divide, HloInstruction::CreateUnary(divide->shape(), HloOpcode::kExp, + subtract)); + changed_ = true; + } + + return Status::OK(); +} + +Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply, + HloInstruction* lhs, + HloInstruction* rhs) { + // A*1 => A + VLOG(10) << "trying transform [A*1 => A]: " << multiply->ToString(); + if (IsLiteralWithValue(rhs, 1) && + ReplaceInstructionIfSameShape(multiply, lhs)) { + return Status::OK(); + } + // 1*A => A + VLOG(10) << "trying transform [1*A => A]: " << multiply->ToString(); + if (IsLiteralWithValue(lhs, 1) && + ReplaceInstructionIfSameShape(multiply, rhs)) { + return Status::OK(); + } + return Status::OK(); +} + +Status AlgebraicSimplifierVisitor::HandleLog(HloInstruction* log, + HloInstruction* operand) { + // ln(exp(A)) => A + VLOG(10) << "trying transform [ln(exp(A)) => A]: " << log->ToString(); + if (operand->opcode() == HloOpcode::kExp && + ReplaceInstructionIfSameShape(log, operand->mutable_operand(0))) { + return Status::OK(); + } + return Status::OK(); +} + +Status AlgebraicSimplifierVisitor::HandleGetTupleElement( + HloInstruction* get_tuple_element, HloInstruction* operand) { + if (operand->opcode() == HloOpcode::kTuple) { + // get_tuple_element(make_tuple({A_0, A_1, ..., A_n}), i) => A_i + VLOG(10) << "trying transform " + << "[get_tuple_element(make_tuple({...,A_i,...}), i)] => A_i: " + << get_tuple_element->ToString(); + if (ReplaceInstructionIfSameShape( + get_tuple_element, + operand->mutable_operand(get_tuple_element->tuple_index()))) { + return Status::OK(); + } + } + return Status::OK(); +} + +namespace { + +// Return whether the given reshape instruction leaves the dimensions at the +// given input indices unmodified, and returns their output indices. +// +// Example: +// input_dim_indices = {2, 3} +// input shape = T[a, b, x, y, cd] +// output shape = T[ab, x, 1, y, c, d] +// return value = {1, 3} +// +// Precondition: input_dim_indices is sorted. 
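The worked example in the comment above can be reproduced with a small standalone sketch of the same forward scan; everything here is illustrative standard C++ (the name MapUnmodifiedDims is hypothetical, and the sorted (input_dim, output_dim) pairs are assumed to have been precomputed the way ShapeUtil::DimensionsUnmodifiedByReshape would). The real implementation follows immediately below.

    #include <cstdio>
    #include <utility>
    #include <vector>

    // Maps sorted input dimension indices to their output indices, given the
    // sorted (input_dim, output_dim) pairs left unmodified by the reshape.
    static std::pair<bool, std::vector<int>> MapUnmodifiedDims(
        const std::vector<std::pair<int, int>>& unmodified,
        const std::vector<int>& input_dim_indices) {
      std::vector<int> output_dim_indices;
      size_t i = 0;  // index into unmodified; never rewinds because inputs are sorted
      for (int input_dim : input_dim_indices) {
        while (i < unmodified.size() && unmodified[i].first < input_dim) ++i;
        if (i == unmodified.size() || unmodified[i].first != input_dim) {
          return {false, {}};
        }
        output_dim_indices.push_back(unmodified[i].second);
      }
      return {true, output_dim_indices};
    }

    int main() {
      // T[a,b,x,y,cd] -> T[ab,x,1,y,c,d]: input dims 2 and 3 survive as output
      // dims 1 and 3, so the unmodified pairs are {2,1} and {3,3}.
      auto result = MapUnmodifiedDims({{2, 1}, {3, 3}}, {2, 3});
      std::printf("%d -> {%d, %d}\n", result.first, result.second[0],
                  result.second[1]);  // prints: 1 -> {1, 3}
    }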
+std::pair> ReshapeLeavesDimensionsUnmodified( + const HloInstruction* hlo, + tensorflow::gtl::ArraySlice input_dim_indices) { + CHECK_EQ(HloOpcode::kReshape, hlo->opcode()); + CHECK(std::is_sorted(input_dim_indices.begin(), input_dim_indices.end())); + + std::vector output_dim_indices; + std::vector> unmodified_dims = + ShapeUtil::DimensionsUnmodifiedByReshape(hlo->operand(0)->shape(), + hlo->shape()); + size_t i = 0; // index to unmodified_dims + for (int64 input_dim_index : input_dim_indices) { + // Search unmodified_dims for input_dim_index. We can search from the last + // matching position because input_dim_indices is guaranteed to be sorted. + while (i < unmodified_dims.size() && + unmodified_dims[i].first < input_dim_index) { + ++i; + } + if (i >= unmodified_dims.size() || + unmodified_dims[i].first != input_dim_index) { + return std::make_pair(false, std::vector()); + } + output_dim_indices.push_back(unmodified_dims[i].second); + } + return std::make_pair(true, output_dim_indices); +} + +// Returns true if the output of "instruction" is a permutation of the elements +// of "operand". Precondition: "operand" is an operand of "instruction". +bool OutputIsPermutationOfOperandElements(HloInstruction* instruction, + HloInstruction* operand) { + DCHECK(!instruction->OperandIndices(operand).empty()); + switch (instruction->opcode()) { + case HloOpcode::kReshape: + case HloOpcode::kReverse: + case HloOpcode::kSort: + case HloOpcode::kTranspose: + return true; + default: + return false; + } +} + +// Returns true if the output of "instruction" is a subset of the elements of +// "operand". Precondition: "operand" is an operand of "instruction". +bool OutputIsSubsetOfOperandElements(HloInstruction* instruction, + HloInstruction* operand) { + std::vector operand_indices = instruction->OperandIndices(operand); + CHECK(!operand_indices.empty()); + if (operand_indices.size() != 1) { + return false; + } + int64 operand_index = operand_indices[0]; + switch (instruction->opcode()) { + case HloOpcode::kSlice: + CHECK_EQ(0, operand_index); + return true; + case HloOpcode::kDynamicSlice: + return operand_index == 0; + default: + return false; + } +} + +} // namespace + +Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { + auto operand = broadcast->mutable_operand(0); + // A degenerate broadcast of a reshape that does not change the number of + // elements can be replaced by a reshape. + if (std::is_sorted(broadcast->dimensions().begin(), + broadcast->dimensions().end()) && + ShapeUtil::ElementsIn(broadcast->shape()) == + ShapeUtil::ElementsIn(operand->shape())) { + VLOG(10) << "transform broadcast(X) -> reshape(X) where " + "n(broadcast(X)) == n(X)"; + computation_->ReplaceWithNewInstruction( + broadcast, HloInstruction::CreateReshape(broadcast->shape(), operand)); + changed_ = true; + return Status::OK(); + } + + // A broadcast of a reshape which merely inserts 1-sized dimensions can elide + // its operand. 
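As a concrete instance of the rewrite described in the comment above (the shapes are made up for illustration): with X of shape f32[3,4], reshape(X) of shape f32[3,1,4], and broadcast(reshape(X)) into f32[3,7,1,4] with dimensions={0,2,3}, dropping the mapping entry for the inserted size-1 operand dimension gives the equivalent broadcast(X) into f32[3,7,1,4] with dimensions={0,3}. A minimal standard-C++ sketch of just that index bookkeeping, mirroring the erase loop below:

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    int main() {
      std::vector<int> dims = {0, 2, 3};        // operand dim -> output dim of the broadcast
      std::vector<int> inserted_indices = {1};  // the reshape inserted operand dim 1 (size 1)
      // Erase from the back so earlier indices stay valid, as the pass does.
      std::reverse(inserted_indices.begin(), inserted_indices.end());
      for (int index : inserted_indices) dims.erase(dims.begin() + index);
      for (int d : dims) std::printf("%d ", d);  // prints: 0 3
      return 0;
    }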
+ { + bool merely_inserts_or_deletes_1_sized_dimensions; + std::vector inserted_indices, deleted_indices; + std::tie(merely_inserts_or_deletes_1_sized_dimensions, deleted_indices, + inserted_indices) = + operand->ReshapeMerelyInsertsOrDeletes1SizedDimensions(); + if (merely_inserts_or_deletes_1_sized_dimensions && + deleted_indices.empty()) { + std::reverse(inserted_indices.begin(), inserted_indices.end()); + auto dims = broadcast->dimensions(); + for (auto inserted_index : inserted_indices) { + dims.erase(dims.begin() + inserted_index); + } + computation_->ReplaceWithNewInstruction( + broadcast, + HloInstruction::CreateBroadcast(broadcast->shape(), + operand->mutable_operand(0), dims)); + changed_ = true; + return Status::OK(); + } + } + + // A scalar broadcast feeding an instruction which only permutes (reshape, + // transpose, sort, reverse) or selects a subset of operand elements (slice, + // dynamic slice) can be replaced with a broadcast directly to the output + // shape of the instruction. + if (ShapeUtil::IsScalar(operand->shape())) { + for (HloInstruction* user : broadcast->users()) { + if (OutputIsPermutationOfOperandElements(user, broadcast) || + OutputIsSubsetOfOperandElements(user, broadcast)) { + HloInstruction* new_broadcast = computation_->AddInstruction( + HloInstruction::CreateBroadcast(user->shape(), operand, {})); + // Use ReplaceUsesOfInstruction instead of ReplaceWithNewInstruction + // because we are replacing an instruction other than the visited + // instruction. + computation_->ReplaceUsesOfInstruction(user, new_broadcast); + changed_ = true; + return Status::OK(); + } + } + } + return Status::OK(); +} + +template +static std::unique_ptr ConvertIfTypesMatch( + const Literal& src_literal) { + CHECK_EQ(primitive_src_type, src_literal.shape().element_type()); + + return HloInstruction::CreateConstant( + LiteralUtil::Convert::type, + typename primitive_util::PrimitiveTypeToNative< + primitive_dest_type>::type>(src_literal)); +} + +template +static std::unique_ptr ConvertIfDestTypeMatches( + const Literal& src_literal, PrimitiveType primitive_dest_type) { + switch (primitive_dest_type) { +#define CONVERT_IF_TYPES_MATCH(type) \ + case (type): \ + return ConvertIfTypesMatch(src_literal); + CONVERT_IF_TYPES_MATCH(PRED) + CONVERT_IF_TYPES_MATCH(S8) + CONVERT_IF_TYPES_MATCH(S32) + CONVERT_IF_TYPES_MATCH(S64) + CONVERT_IF_TYPES_MATCH(U8) + CONVERT_IF_TYPES_MATCH(U32) + CONVERT_IF_TYPES_MATCH(U64) + CONVERT_IF_TYPES_MATCH(F32) + CONVERT_IF_TYPES_MATCH(F64) +#undef CONVERT_IF_TYPES_MATCH + // Other types are not yet supported. + default: + LOG(FATAL) << "Unimplemented: ConvertIfDestTypeMatches for type " + << PrimitiveType_Name(src_literal.shape().element_type()); + } +} + +static std::unique_ptr ConvertIfSrcTypeMatches( + const Literal& src_literal, PrimitiveType primitive_dest_type) { + switch (src_literal.shape().element_type()) { +#define CONVERT_IF_DEST_TYPE_MATCHES(type) \ + case (type): \ + return ConvertIfDestTypeMatches<(type)>(src_literal, primitive_dest_type); + CONVERT_IF_DEST_TYPE_MATCHES(PRED) + CONVERT_IF_DEST_TYPE_MATCHES(S8) + CONVERT_IF_DEST_TYPE_MATCHES(S32) + CONVERT_IF_DEST_TYPE_MATCHES(S64) + CONVERT_IF_DEST_TYPE_MATCHES(U8) + CONVERT_IF_DEST_TYPE_MATCHES(U32) + CONVERT_IF_DEST_TYPE_MATCHES(U64) + CONVERT_IF_DEST_TYPE_MATCHES(F32) + CONVERT_IF_DEST_TYPE_MATCHES(F64) +#undef CONVERT_IF_DEST_TYPE_MATCHES + // Other types are not yet supported. 
+ default: + LOG(FATAL) << "Unimplemented: ConvertIfSrcTypeMatches for type " + << PrimitiveType_Name(src_literal.shape().element_type()); + } +} + +// A conversion to the same element type as the operand is a nop and can be +// removed. A conversion of a constant can be simplified by making a new +// constant. +Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert, + HloInstruction* operand) { + PrimitiveType src_type = operand->shape().element_type(); + PrimitiveType dest_type = convert->shape().element_type(); + if (src_type == dest_type) { + computation_->ReplaceInstruction(convert, operand); + changed_ = true; + return Status::OK(); + } + if (operand->opcode() == HloOpcode::kConstant) { + const Literal& src_literal = operand->literal(); + std::unique_ptr new_constant = + ConvertIfSrcTypeMatches(src_literal, dest_type); + computation_->ReplaceWithNewInstruction(convert, std::move(new_constant)); + changed_ = true; + return Status::OK(); + } + return Status::OK(); +} + +Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { + // The pad instruction does nothing if the output shape is the same as the + // input shape, i.e, all paddings are zero. + ReplaceInstructionIfSameShape(pad, pad->mutable_operand(0)); + return Status::OK(); +} + +Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power, + HloInstruction* lhs, + HloInstruction* rhs) { + VLOG(10) << "trying transform [pow(A, 0) => 1]: " << power->ToString(); + if (IsLiteralWithValue(rhs, 0)) { + auto one = HloInstruction::CreateConstant(LiteralUtil::CloneToUnique( + LiteralUtil::One(power->shape().element_type()))); + std::unique_ptr ones; + if (ShapeUtil::IsScalar(power->shape())) { + ones = std::move(one); + } else { + ones = HloInstruction::CreateBroadcast( + power->shape(), computation_->AddInstruction(std::move(one)), {}); + } + computation_->ReplaceWithNewInstruction(power, std::move(ones)); + changed_ = true; + return Status::OK(); + } + + VLOG(10) << "trying transform [pow(A, 1) => A]: " << power->ToString(); + if (IsLiteralWithValue(rhs, 1) && ReplaceInstructionIfSameShape(power, lhs)) { + return Status::OK(); + } + + VLOG(10) << "trying transform [pow(A, 2) => A*A]: " << power->ToString(); + if (IsLiteralWithValue(rhs, 2)) { + computation_->ReplaceWithNewInstruction( + power, HloInstruction::CreateBinary(power->shape(), + HloOpcode::kMultiply, lhs, lhs)); + changed_ = true; + return Status::OK(); + } + + VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString(); + if (IsLiteralWithValue(rhs, -1)) { + auto* one = computation_->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CloneToUnique( + LiteralUtil::One(rhs->shape().element_type())))); + computation_->ReplaceWithNewInstruction( + power, HloInstruction::CreateBinary(power->shape(), HloOpcode::kDivide, + one, lhs)); + changed_ = true; + return Status::OK(); + } + return Status::OK(); +} + +Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { + auto operand = reshape->mutable_operand(0); + + // Delete no-op reshapes, i.e. where shape = operand shape. + if (SameShape(reshape, operand)) { + VLOG(10) << "deleting no-op reshape"; + computation_->ReplaceInstruction(reshape, operand); + changed_ = true; + return Status::OK(); + } + + // Merge reshapes. 
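As a concrete instance of the merge performed just below (shapes illustrative): reshape(reshape(X: f32[2,3,4] -> f32[6,4]) -> f32[24]) collapses to a single reshape(X) -> f32[24]; since a reshape preserves the row-major order of elements, only the outermost target shape matters.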
+ if (HloOpcode::kReshape == operand->opcode()) { + computation_->ReplaceWithNewInstruction( + reshape, HloInstruction::CreateReshape(reshape->shape(), + operand->mutable_operand(0))); + changed_ = true; + return Status::OK(); + } + + if (HloOpcode::kBroadcast == reshape->operand(0)->opcode()) { + auto opt_dims = ReshapeLeavesDimensionsUnmodified( + reshape, reshape->operand(0)->dimensions()); + if (opt_dims.first) { + computation_->ReplaceWithNewInstruction( + reshape, + HloInstruction::CreateBroadcast( + reshape->shape(), reshape->mutable_operand(0)->mutable_operand(0), + opt_dims.second)); + changed_ = true; + return Status::OK(); + } + } + + // Make this a bitcast if possible. + if (is_layout_sensitive_ && + ReshapeIsBitcast(reshape, valid_bitcast_callback_)) { + ReplaceWithBitcast(reshape); + return Status::OK(); + } + + return Status::OK(); +} + +Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice, + HloInstruction* operand) { + // Delete no-op slices, i.e. where shape = operand shape. + if (ReplaceInstructionIfSameShape(slice, operand)) { + return Status::OK(); + } + return Status::OK(); +} + +Status AlgebraicSimplifierVisitor::HandleReduce( + HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value, + tensorflow::gtl::ArraySlice dimensions, HloComputation* function) { + if (ShapeUtil::ElementsIn(reduce->shape()) == + ShapeUtil::ElementsIn(arg->shape())) { + auto reshape = computation_->AddInstruction( + HloInstruction::CreateReshape(reduce->shape(), arg)); + computation_->ReplaceWithNewInstruction( + reduce, HloInstruction::CreateMap(reduce->shape(), + {reshape, init_value}, function)); + return Status::OK(); + } + return Status::OK(); +} + +Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { + auto operand = transpose->mutable_operand(0); + + if (std::is_sorted(transpose->dimensions().begin(), + transpose->dimensions().end())) { + VLOG(10) << "deleting no-op transpose"; + computation_->ReplaceInstruction(transpose, operand); + changed_ = true; + return Status::OK(); + } + + if (HloOpcode::kTranspose == operand->opcode()) { + computation_->ReplaceWithNewInstruction( + transpose, HloInstruction::CreateTranspose( + transpose->shape(), operand->mutable_operand(0), + ComposePermutations(operand->dimensions(), + transpose->dimensions()))); + changed_ = true; + return Status::OK(); + } + + if (is_layout_sensitive_ && + TransposeIsBitcast(transpose, valid_bitcast_callback_)) { + ReplaceWithBitcast(transpose); + return Status::OK(); + } + + return Status::OK(); +} + +Status AlgebraicSimplifierVisitor::HandleConvolution( + HloInstruction* convolution, HloInstruction* lhs, HloInstruction* rhs, + const Window& window) { + // HandleConvolution tries to replace a convolution with a DOT instruction. + // + // Only add when bitcasts can be used: + // - if bitcasts are not supported, then reshapes could be used but will + // end up with another copy. + // - if bitcasts are supported, the simplifier will be called again with + // bitcasts_ == true. + + // TODO(cwhipkey): b/31337498, make this layout insensitive. 
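A standalone check of the equivalence this handler relies on, in plain standard C++ with made-up sizes and an assumed channel-most-minor (NHWC-style) input layout; none of the names or sizes below come from the change itself. With a 1x1 filter and no stride, padding, or dilation, the convolution touches each (batch, spatial) position exactly once, which is the same arithmetic as a [N*H*W, Cin] x [Cin, Cout] matrix multiply over the same row-major buffers:

    #include <cassert>
    #include <vector>

    int main() {
      const int N = 2, H = 3, W = 3, Cin = 4, Cout = 5;
      std::vector<float> input(N * H * W * Cin), filter(Cin * Cout);
      for (size_t i = 0; i < input.size(); ++i) input[i] = float(i % 7);
      for (size_t i = 0; i < filter.size(); ++i) filter[i] = float(i % 5);

      // Naive 1x1 convolution: NHWC input, filter collapsed to [Cin, Cout].
      std::vector<float> conv(N * H * W * Cout, 0.f);
      for (int n = 0; n < N; ++n)
        for (int h = 0; h < H; ++h)
          for (int w = 0; w < W; ++w)
            for (int co = 0; co < Cout; ++co)
              for (int ci = 0; ci < Cin; ++ci)
                conv[((n * H + h) * W + w) * Cout + co] +=
                    input[((n * H + h) * W + w) * Cin + ci] * filter[ci * Cout + co];

      // The same buffers read as a [N*H*W, Cin] x [Cin, Cout] matmul.
      std::vector<float> dot(N * H * W * Cout, 0.f);
      for (int row = 0; row < N * H * W; ++row)
        for (int co = 0; co < Cout; ++co)
          for (int ci = 0; ci < Cin; ++ci)
            dot[row * Cout + co] += input[row * Cin + ci] * filter[ci * Cout + co];

      assert(conv == dot);  // identical element order, so exactly equal results
      return 0;
    }

Because the data never has to move, the pass only needs bitcasts around the dot, provided the layout checks that follow succeed.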
+ if (!is_layout_sensitive_) return Status::OK(); + + const ConvolutionDimensionNumbers& dnums = + convolution->convolution_dimension_numbers(); + const Shape& input_shape = lhs->shape(); + const Shape& filter_shape = rhs->shape(); + const Shape& convolution_shape = convolution->shape(); + TF_RET_CHECK(LayoutUtil::HasLayout(input_shape)); + TF_RET_CHECK(LayoutUtil::HasLayout(filter_shape)); + TF_RET_CHECK(LayoutUtil::HasLayout(convolution_shape)); + + // Require 1x1 filter in the spatial dimensions (so no need to extract image + // patches). + if (filter_shape.dimensions(dnums.kernel_spatial_dimensions(0)) != 1 || + filter_shape.dimensions(dnums.kernel_spatial_dimensions(1)) != 1) { + return Status::OK(); + } + + // Stride ignores part of the output, which matrix multiplication does not do, + // so require no stride. Padding and base (lhs) dilation both implicitly + // extend the data, which matrix multiplication also does not do, so require + // no padding and no base (lhs) dilation. Window (rhs) dilation has no effect + // for a 1x1 window, so window dilation is no problem. + if (window_util::HasStride(window) || window_util::HasPadding(window) || + window_util::HasBaseDilation(window)) { + return Status::OK(); + } + + // Also, the shapes must align for a rowmajor matmul: + // - the input and output have the same layout. + // - for input/output, the channel dimension must be the most minor. Other + // spatial dims can be in any order. + // - for filters, the input channel dimension must be more major than the + // output channel dimension. The width+height don't matter because + // they are 1. + // + // These constraints are harsh. If the channel dimension is the most major + // and/or the layout of input/output feature dimensions are reversed, we can + // still convert Conv into more efficient Matmul with operand transposition + // (such as the transposition flags in cuBLAS SGEMM). + if (!LayoutUtil::Equal(input_shape.layout(), convolution_shape.layout()) || + input_shape.layout().minor_to_major(0) != dnums.feature_dimension() || + // The input feature dimension should come later in the minor-to-major + // order. + (PositionInContainer(AsInt64Slice(filter_shape.layout().minor_to_major()), + dnums.kernel_input_feature_dimension()) < + PositionInContainer(AsInt64Slice(filter_shape.layout().minor_to_major()), + dnums.kernel_output_feature_dimension()))) { + return Status::OK(); + } + + auto add_bitcast = [&](Shape shape, HloInstruction* operand) { + std::vector dims(operand->shape().dimensions_size()); + std::iota(dims.begin(), dims.end(), 0); + return computation_->AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kBitcast, operand)); + }; + + // Replace it with a dot, with bitcasts around it to get the right shape. + const int64 input_channels = + input_shape.dimensions(dnums.feature_dimension()); + const int64 output_channels = + filter_shape.dimensions(dnums.kernel_output_feature_dimension()); + + // Computes the product of the non-feature dimensions. + int64 conv_width = 1; + for (int i = 0; i < input_shape.dimensions_size(); ++i) { + if (i != dnums.feature_dimension()) { + conv_width *= input_shape.dimensions(i); + } + } + + // We already checked feature_dimension is most minor, so data in input_shape + // and row-major {conv_width,input_channels} are bitwise identical. 
+ const Shape new_input_shape = + ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( + input_shape.element_type(), {conv_width, input_channels}); + // We already checked input_feature_dimension is more major than + // output_feature_dimension, so data in filter_shape and row-major + // {input_channels,output_channels} are bitwise identical. + const Shape new_filter_shape = + ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( + filter_shape.element_type(), {input_channels, output_channels}); + const Shape dot_output_shape = + ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( + convolution_shape.element_type(), {conv_width, output_channels}); + + // We cannot insert bitcasts if the layouts will not be compatible. + // TODO(b/33178038): Consider inserting a transpose if a bitcast would be + // invalid. + if (!valid_bitcast_callback_(lhs->shape(), input_shape) || + !valid_bitcast_callback_(rhs->shape(), new_filter_shape) || + !valid_bitcast_callback_(dot_output_shape, convolution_shape)) { + return Status::OK(); + } + + auto new_lhs = add_bitcast(new_input_shape, lhs); + auto new_rhs = add_bitcast(new_filter_shape, rhs); + auto dot = computation_->AddInstruction(HloInstruction::CreateBinary( + dot_output_shape, HloOpcode::kDot, new_lhs, new_rhs)); + computation_->ReplaceInstruction(convolution, + add_bitcast(convolution_shape, dot)); + changed_ = true; + return Status::OK(); +} + +bool AlgebraicSimplifierVisitor::TransformToClampIfSameShape( + HloInstruction* root, HloInstruction* min, HloInstruction* min_operand, + HloInstruction* operand, HloInstruction* max, HloInstruction* max_operand) { + // Ensure shapes of min and max operand are equal to match current shape + // inference. + if (!SameShape(min_operand, max_operand)) { + return false; + } + + auto clamp = HloInstruction::CreateTernary(root->shape(), HloOpcode::kClamp, + max_operand, operand, min_operand); + computation_->ReplaceWithNewInstruction(root, std::move(clamp)); + changed_ = true; + return true; +} + +Status AlgebraicSimplifierVisitor::HandleMaximum(HloInstruction* maximum, + HloInstruction* lhs, + HloInstruction* rhs) { + // Match the following tree: + // min_operand operand + // \ / + // max_operand min + // \ / + // max + // where max_operand and min_operand are scalar constants. + { + HloInstruction* min; + HloInstruction* max_operand; + HloInstruction* min_operand; + HloInstruction* operand; + + if (hlo_query::MatchBinaryInstructionOperandOpcode( + HloOpcode::kMinimum, maximum, + /*matching_operand=*/&min, + /*other_operand=*/&max_operand) && + hlo_query::MatchBinaryInstructionOperand( + hlo_query::IsScalarConstant, min, + /*matching_operand=*/&min_operand, + /*other_operand=*/&operand) && + TransformToClampIfSameShape(maximum, min, min_operand, operand, maximum, + max_operand)) { + return Status::OK(); + } + } + + return Status::OK(); +} + +Status AlgebraicSimplifierVisitor::HandleMinimum(HloInstruction* minimum, + HloInstruction* lhs, + HloInstruction* rhs) { + // Match the following tree: + // max_operand operand + // \ / + // min_operand max + // \ / + // min + // where max_operand and min_operand are scalar constants. 
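The matching code for this tree follows. A scalar sketch in standard C++ (values made up) of why the rewrite is sound, modeling Clamp(lo, operand, hi) as max(lo, min(operand, hi)) and assuming the lower-bound constant does not exceed the upper-bound constant:

    #include <algorithm>
    #include <cassert>
    #include <initializer_list>

    int main() {
      const float lo = 0.f, hi = 6.f;  // playing the roles of max_operand and min_operand
      for (float x : {-3.f, 4.f, 9.f}) {
        float tree = std::min(std::max(x, lo), hi);   // min(max(operand, lo), hi)
        float clamp = std::max(lo, std::min(x, hi));  // Clamp(lo, operand, hi)
        assert(tree == clamp);
      }
      return 0;
    }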
+ { + HloInstruction* max; + HloInstruction* max_operand; + HloInstruction* min_operand; + HloInstruction* operand; + + if (hlo_query::MatchBinaryInstructionOperandOpcode( + HloOpcode::kMaximum, minimum, + /*matching_operand=*/&max, + /*other_operand=*/&min_operand) && + hlo_query::MatchBinaryInstructionOperand( + hlo_query::IsScalarConstant, max, + /*matching_operand=*/&max_operand, + /*other_operand=*/&operand) && + TransformToClampIfSameShape(minimum, minimum, min_operand, operand, max, + max_operand)) { + return Status::OK(); + } + } + + return Status::OK(); +} + +StatusOr AlgebraicSimplifier::Run(HloModule* module) { + return std::any_of( + module->computations().begin(), module->computations().end(), + [=](const std::unique_ptr& computation) { + return AlgebraicSimplifierVisitor::Run( + computation.get(), is_layout_sensitive_, valid_bitcast_callback_); + }); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h new file mode 100644 index 0000000000..4aa06cab53 --- /dev/null +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h @@ -0,0 +1,56 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_ALGEBRAIC_SIMPLIFIER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_ALGEBRAIC_SIMPLIFIER_H_ + +#include + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass.h" + +namespace xla { + +// A pass which performs AlgebraicSimplications. +class AlgebraicSimplifier : public HloPass { + public: + // Given two shapes, determines if it is valid to bitcast between them. + // Precondition: the two shapes have layouts and have the same number of + // elements. + using ValidBitcastCallback = std::function; + + // If is_layout_sensitive is true, then the simplifier preserves layout during + // transformation. Otherwise, layout is ignored. If valid_bitcast_callback + // returns true, then the pass will replace reshapes and tranposes with + // bitcasts. + AlgebraicSimplifier(bool is_layout_sensitive, + ValidBitcastCallback valid_bitcast_callback) + : HloPass("algsimp"), + is_layout_sensitive_(is_layout_sensitive), + valid_bitcast_callback_(std::move(valid_bitcast_callback)) {} + ~AlgebraicSimplifier() override {} + + // Run algebraic simplification on the given computation. Returns whether the + // computation was changed. 
+ StatusOr Run(HloModule* module) override; + + private: + bool is_layout_sensitive_; + ValidBitcastCallback valid_bitcast_callback_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_ALGEBRAIC_SIMPLIFIER_H_ diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc new file mode 100644 index 0000000000..49ea91f83b --- /dev/null +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -0,0 +1,1368 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/algebraic_simplifier.h" + +#include +#include + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/strings/str_util.h" + +namespace xla { +namespace { + +AlgebraicSimplifier::ValidBitcastCallback bitcasting_callback() { + return [](const Shape&, const Shape&) { return true; }; +} +AlgebraicSimplifier::ValidBitcastCallback non_bitcasting_callback() { + return [](const Shape&, const Shape&) { return false; }; +} + +using AlgebraicSimplifierTest = HloTestBase; + +// Test that A + 0 is simplified to A +TEST_F(AlgebraicSimplifierTest, AddZero) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "param0")); + HloInstruction* zero = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, zero)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kAdd); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, param0); +} + +// Test that A - 0 is simplified to A +TEST_F(AlgebraicSimplifierTest, SubZero) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "param0")); + HloInstruction* zero = builder.AddInstruction( + 
HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kSubtract, param0, zero)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kSubtract); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, param0); +} + +// Test that A/1 is simplified to A for a scalar. +TEST_F(AlgebraicSimplifierTest, DivOneScalar) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "param0")); + HloInstruction* one = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); + HloInstruction* div = builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, one)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root, div); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, param0); +} + +// Test that A/1 is simplified to A for an array. +TEST_F(AlgebraicSimplifierTest, DivOneArray) { + Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r2f32, "param0")); + HloInstruction* one = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{1.0, 1.0}, {1.0, 1.0}}))); + HloInstruction* div = builder.AddInstruction( + HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param0, one)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root, div); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, param0); +} + +// Test that get_element(make_tuple({A,B}),1) is simplified to B +TEST_F(AlgebraicSimplifierTest, SelectMakeTuple) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, r0f32, "param1")); + HloInstruction* param2 = builder.AddInstruction( + HloInstruction::CreateParameter(2, r0f32, "param2")); + HloInstruction* tuple = + builder.AddInstruction(HloInstruction::CreateTuple({param0, param1})); + HloInstruction* get = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(r0f32, tuple, 1)); + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, get, param2)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + HloInstruction* root = 
computation->root_instruction(); + EXPECT_EQ(root, add); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, add); + EXPECT_EQ(root->operand(0), param1); + EXPECT_EQ(root->operand(1), param2); +} + +// Test that exp(A)/exp(B) is simplified to exp(A-B) +TEST_F(AlgebraicSimplifierTest, ExpDiv) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, r0f32, "param1")); + HloInstruction* exp0 = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param0)); + HloInstruction* exp1 = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param1)); + builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, exp0, exp1)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kDivide); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kExp); + EXPECT_EQ(root->operand_count(), 1); + EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kSubtract); + EXPECT_EQ(root->operand(0)->operand(0), param0); + EXPECT_EQ(root->operand(0)->operand(1), param1); +} + +// Test that ln(exp(A)) is simplified to A +TEST_F(AlgebraicSimplifierTest, LnExp) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "param0")); + HloInstruction* exp0 = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param0)); + builder.AddInstruction( + HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, exp0)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kLog); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kParameter); + EXPECT_EQ(root, param0); +} + +// Test that ln(exp(A)/exp(B)) is simplified to A-B +TEST_F(AlgebraicSimplifierTest, LnExpDiv) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, r0f32, "param1")); + HloInstruction* exp0 = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param0)); + HloInstruction* exp1 = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param1)); + HloInstruction* div = builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, exp0, exp1)); + builder.AddInstruction( + HloInstruction::CreateUnary(r0f32, 
HloOpcode::kLog, div)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kLog); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kSubtract); + EXPECT_EQ(root->operand(0), param0); + EXPECT_EQ(root->operand(1), param1); +} + +// Test that pow(A, 0) where A is a scalar is simplified to the scalar +// constant 1. +TEST_F(AlgebraicSimplifierTest, Pow0Scalar) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "param0")); + HloInstruction* zero = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, zero)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kConstant); + EXPECT_EQ(LiteralUtil::GetFirstElement(root->literal()), 1); +} + +// Test that pow(A, 0) where A is not a scalar is simplified to broadcast(1). +TEST_F(AlgebraicSimplifierTest, Pow0Vector) { + Shape r1f32 = ShapeUtil::MakeShape(F32, {42}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r1f32, "param0")); + HloInstruction* zero = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + builder.AddInstruction( + HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, param0, zero)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); + EXPECT_TRUE(ShapeUtil::Equal(root->shape(), r1f32)) + << ShapeUtil::HumanString(root->shape()); + EXPECT_EQ(root->dimensions().size(), 0); + EXPECT_TRUE(ShapeUtil::IsScalar(root->operand(0)->shape())); + EXPECT_EQ(LiteralUtil::GetFirstElement(root->operand(0)->literal()), + 1); +} + +// Test that pow(A, 1) is simplified to A. 
+TEST_F(AlgebraicSimplifierTest, Pow1) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "param0")); + HloInstruction* one = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); + builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, one)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kParameter); + EXPECT_EQ(root, param0); +} + +// Test that pow(A, 2) is simplified to A*A. +TEST_F(AlgebraicSimplifierTest, Pow2) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "param0")); + HloInstruction* two = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2))); + builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, two)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kMultiply); + EXPECT_EQ(root->operand(0), param0); + EXPECT_EQ(root->operand(1), param0); +} + +// Test that pow(A, -1) is simplified to 1/A. 
+TEST_F(AlgebraicSimplifierTest, PowNegative1) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "param0")); + HloInstruction* negative_one = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(-1))); + builder.AddInstruction(HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, + param0, negative_one)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kDivide); + EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kConstant); + EXPECT_EQ(LiteralUtil::GetFirstElement(root->operand(0)->literal()), + 1); + EXPECT_EQ(root->operand(1), param0); +} + +TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + + auto builder = HloComputation::Builder(TestName()); + auto op = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {3, 2}), "op")); + auto reshape1 = builder.AddInstruction( + HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {6}), op)); + auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShape(F32, {1, 6}), reshape1, {1})); + builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {3, 2}), broadcast)); + + auto computation = builder.Build(); + auto module = MakeUnique(TestName()); + module->AddEntryComputation(std::move(computation)); + HloInstruction* root = module->entry_computation()->root_instruction(); + HloPassFix simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + root = module->entry_computation()->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kParameter); +} + +// Test that convert(A, $TYPE) is simplified to A if A is of type $TYPE. 
+TEST_F(AlgebraicSimplifierTest, ConvertBetweenSameType) { + HloComputation::Builder builder(TestName()); + HloInstruction* input = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + builder.AddInstruction( + HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_EQ(HloOpcode::kConvert, computation->root_instruction()->opcode()); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + + EXPECT_EQ(HloOpcode::kConstant, computation->root_instruction()->opcode()); +} + +TEST_F(AlgebraicSimplifierTest, ConvertF32ToS64) { + HloComputation::Builder builder(TestName()); + HloInstruction* input = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + builder.AddInstruction( + HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {}), input)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_EQ(HloOpcode::kConvert, computation->root_instruction()->opcode()); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + + EXPECT_EQ(HloOpcode::kConstant, computation->root_instruction()->opcode()); + EXPECT_EQ(LiteralUtil::GetFirstElement( + computation->root_instruction()->literal()), + 42); +} + +TEST_F(AlgebraicSimplifierTest, ConvertS64ToF32) { + HloComputation::Builder builder(TestName()); + HloInstruction* input = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42))); + builder.AddInstruction( + HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_EQ(HloOpcode::kConvert, computation->root_instruction()->opcode()); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + + EXPECT_EQ(HloOpcode::kConstant, computation->root_instruction()->opcode()); + EXPECT_EQ(LiteralUtil::GetFirstElement( + computation->root_instruction()->literal()), + 42.0f); +} + +TEST_F(AlgebraicSimplifierTest, ConvertF32ArrayToS64Array) { + HloComputation::Builder builder(TestName()); + HloInstruction* input = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({42.0f, 19.0f}))); + builder.AddInstruction( + HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {2}), input)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_EQ(HloOpcode::kConvert, computation->root_instruction()->opcode()); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + + EXPECT_EQ(HloOpcode::kConstant, computation->root_instruction()->opcode()); + EXPECT_EQ( + LiteralUtil::Get(computation->root_instruction()->literal(), {0}), + 42); + EXPECT_EQ( + LiteralUtil::Get(computation->root_instruction()->literal(), {1}), + 19); +} + +// Test that copies are removed. 
+TEST_F(AlgebraicSimplifierTest, RemoveCopy) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "param0")); + HloInstruction* copy = builder.AddInstruction( + HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_EQ(copy, computation->root_instruction()); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + + EXPECT_EQ(param0, computation->root_instruction()); +} + +// Test that a simplification which changes layouts is not performed if layout +// sensitive is true. +TEST_F(AlgebraicSimplifierTest, CopyWithDifferentLayout) { + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {2, 2}), "param0")); + HloInstruction* copy = builder.AddInstruction( + HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + + // Set to different layouts. + *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); + *copy->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 0}); + + EXPECT_EQ(copy, computation->root_instruction()); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, + non_bitcasting_callback()); + EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); + + // Copy has not been removed. + EXPECT_EQ(copy, computation->root_instruction()); +} + +// Test that a simplification which preserves layouts is performed if layout +// sensitive is true. +TEST_F(AlgebraicSimplifierTest, CopyWithSameLayout) { + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {2, 2}), "param0")); + HloInstruction* copy = builder.AddInstruction( + HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + + // Set to same layouts. + *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); + *copy->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); + + EXPECT_EQ(copy, computation->root_instruction()); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + + // Copy has been removed. + EXPECT_EQ(param0, computation->root_instruction()); +} + +// Test that a reshape which could be replaced with a bitcast is not if +// add_bitcasts is false. 
+TEST_F(AlgebraicSimplifierTest, NoBitcastAdded) { + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {2, 2}), "param0")); + HloInstruction* reshape = + builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), param0)); + + *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); + *reshape->mutable_shape()->mutable_layout() = + LayoutUtil::MakeLayout({0, 1, 2, 3, 4, 5}); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_EQ(reshape, computation->root_instruction()); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, + non_bitcasting_callback()); + EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); + + // Reshape is not replaced with a bitcast. + EXPECT_EQ(reshape, computation->root_instruction()); +} + +// Test transforming reshapes to bitcasts under various conditions. +TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) { + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {2, 2}), "param0")); + *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); + + // Reshape which can be transformed into a bitcast. + HloInstruction* transformable_reshape = + builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), param0)); + *transformable_reshape->mutable_shape()->mutable_layout() = + LayoutUtil::MakeLayout({0, 1, 2, 3, 4, 5}); + + // Reshape does not just add degenerate dimensions. + HloInstruction* dimensions_wrong_reshape = + builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {1, 4, 1, 1, 1, 1}), param0)); + *dimensions_wrong_reshape->mutable_shape()->mutable_layout() = + LayoutUtil::MakeLayout({0, 1, 2, 3, 4, 5}); + + // Reshape has wrong layout. + HloInstruction* layout_wrong_reshape = + builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), param0)); + *layout_wrong_reshape->mutable_shape()->mutable_layout() = + LayoutUtil::MakeLayout({5, 4, 3, 2, 1, 0}); + + // Collect all the reshapes into a tuple so they are not dead. + builder.AddInstruction(HloInstruction::CreateTuple( + {transformable_reshape, dimensions_wrong_reshape, layout_wrong_reshape})); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_EQ(transformable_reshape, computation->root_instruction()->operand(0)); + EXPECT_EQ(dimensions_wrong_reshape, + computation->root_instruction()->operand(1)); + EXPECT_EQ(layout_wrong_reshape, computation->root_instruction()->operand(2)); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, + bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + + // Verify that only the first reshape is replaced. 
+ EXPECT_NE(transformable_reshape, computation->root_instruction()->operand(0)); + EXPECT_EQ(HloOpcode::kBitcast, + computation->root_instruction()->operand(0)->opcode()); + EXPECT_EQ(dimensions_wrong_reshape, + computation->root_instruction()->operand(1)); + EXPECT_EQ(layout_wrong_reshape, computation->root_instruction()->operand(2)); +} + +TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) { + HloComputation::Builder builder(TestName()); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {50, 14, 14, 64}), "param")); + *param->mutable_shape()->mutable_layout() = + LayoutUtil::MakeLayout({1, 2, 0, 3}); + + HloInstruction* transpose = + builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {14, 14, 50, 64}), param, {1, 2, 0, 3})); + *transpose->mutable_shape()->mutable_layout() = + LayoutUtil::MakeLayout({0, 1, 2, 3}); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, + bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + + // Verify that the reshape is replaced. + EXPECT_EQ(2, computation->instruction_count()); + EXPECT_EQ(HloOpcode::kBitcast, computation->root_instruction()->opcode()); +} + +TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast2) { + HloComputation::Builder builder(TestName()); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {5, 2, 3, 4}), "param")); + *param->mutable_shape()->mutable_layout() = + LayoutUtil::MakeLayout({1, 2, 3, 0}); + + HloInstruction* transpose = + builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {5, 3, 4, 2}), param, {0, 2, 3, 1})); + *transpose->mutable_shape()->mutable_layout() = + LayoutUtil::MakeLayout({3, 1, 2, 0}); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, + bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + + // Verify that the reshape is replaced. 
+ EXPECT_EQ(2, computation->instruction_count()); + EXPECT_EQ(HloOpcode::kBitcast, computation->root_instruction()->opcode()); +} + +TEST_F(AlgebraicSimplifierTest, ReshapesMerged) { + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {2, 2}), "param0")); + + HloInstruction* reshape1 = + builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {2, 1, 2}), param0)); + + HloInstruction* reshape2 = + builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), reshape1)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_EQ(reshape2, computation->root_instruction()); + EXPECT_EQ(reshape1, computation->root_instruction()->operand(0)); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + + EXPECT_EQ(HloOpcode::kReshape, computation->root_instruction()->opcode()); + EXPECT_EQ(HloOpcode::kParameter, + computation->root_instruction()->operand(0)->opcode()); +} + +TEST_F(AlgebraicSimplifierTest, TransposesMerged) { + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {2, 3, 4}), "param0")); + + HloInstruction* transpose1 = + builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {3, 4, 2}), param0, {1, 2, 0})); + + HloInstruction* transpose2 = + builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {4, 3, 2}), transpose1, {1, 0, 2})); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_EQ(transpose2, computation->root_instruction()); + EXPECT_EQ(transpose1, computation->root_instruction()->operand(0)); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + + EXPECT_EQ(HloOpcode::kTranspose, computation->root_instruction()->opcode()); + EXPECT_EQ(std::vector({2, 1, 0}), + computation->root_instruction()->dimensions()); + EXPECT_EQ(HloOpcode::kParameter, + computation->root_instruction()->operand(0)->opcode()); +} + +// Test merging reshape and broadcast. +TEST_F(AlgebraicSimplifierTest, ReshapeAndBroadcastMerged) { + HloComputation::Builder builder(TestName()); + auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {5}), "param0")); + auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {1, 5, 1}), param0)); + builder.AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShape(F32, {1, 2, 3, 5, 1}), reshape1, {0, 2, 3})); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + + EXPECT_EQ(HloOpcode::kBroadcast, computation->root_instruction()->opcode()); + EXPECT_EQ(HloOpcode::kParameter, + computation->root_instruction()->operand(0)->opcode()); +} + +// Test merging broadcast and reshape. 
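+// Here the reshape only regroups dimensions that the broadcast introduced, so
+// the pair is expected to collapse into a single broadcast of the parameter.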
+TEST_F(AlgebraicSimplifierTest, BroadcastAndReshapeMerged) { + HloComputation::Builder builder(TestName()); + auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {2, 3}), "param0")); + auto broadcast1 = builder.AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShape(F32, {1, 2, 3, 7, 12, 1}), param0, {1, 2})); + builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {2, 3, 7, 2, 1, 3, 2}), broadcast1)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + + EXPECT_EQ(HloOpcode::kBroadcast, computation->root_instruction()->opcode()); + EXPECT_EQ(HloOpcode::kParameter, + computation->root_instruction()->operand(0)->opcode()); +} + +TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x1_3) { + HloComputation::Builder builder(TestName()); + auto param = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1}), "param")); + auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShape(F32, {3, 1}), param, {1})); + builder.AddInstruction( + HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {3}), broadcast)); + + auto module = MakeUnique(TestName()); + module->AddEntryComputation(builder.Build()); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); +} + +TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4_6x1x1x4) { + HloComputation::Builder builder(TestName()); + auto param = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {4}), "param")); + auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShape(F32, {3, 2, 4}), param, {2})); + builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {6, 1, 1, 4}), broadcast)); + + auto module = MakeUnique(TestName()); + HloComputation* computation = module->AddEntryComputation(builder.Build()); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_EQ(HloOpcode::kBroadcast, computation->root_instruction()->opcode()); + EXPECT_MATCH(computation->root_instruction()->dimensions(), + testing::VectorMatcher({3})); +} + +TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x2x1_6x1x1x1) { + HloComputation::Builder builder(TestName()); + auto param = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1}), "param")); + auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShape(F32, {3, 2, 1}), param, {2})); + builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {6, 1, 1, 1}), broadcast)); + + auto module = MakeUnique(TestName()); + HloComputation* computation = module->AddEntryComputation(builder.Build()); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_EQ(HloOpcode::kBroadcast, computation->root_instruction()->opcode()); + const std::vector broadcast_dims = + computation->root_instruction()->dimensions(); + EXPECT_EQ(1, broadcast_dims.size()); + 
EXPECT_TRUE(broadcast_dims[0] == 1 || broadcast_dims[0] == 2 || + broadcast_dims[3] == 3); +} + +TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4x2_6x8) { + HloComputation::Builder builder(TestName()); + auto param = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {4}), "param")); + auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShape(F32, {3, 2, 4, 2}), param, {2})); + builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {6, 8}), broadcast)); + + auto module = MakeUnique(TestName()); + module->AddEntryComputation(builder.Build()); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); +} + +TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) { + HloComputation::Builder builder(TestName()); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {2, 2}), "param")); + HloInstruction* zero = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + PaddingConfig no_padding; + for (auto i = 0; i < 2; ++i) { + auto dimension = no_padding.add_dimensions(); + dimension->set_edge_padding_low(0); + dimension->set_edge_padding_high(0); + dimension->set_interior_padding(0); + } + builder.AddInstruction(HloInstruction::CreatePad( + ShapeUtil::MakeShape(F32, {2, 2}), param, zero, no_padding)); + + HloModule module(TestName()); + HloComputation* computation = module.AddEntryComputation(builder.Build()); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + EXPECT_EQ(1, computation->instruction_count()); +} + +TEST_F(AlgebraicSimplifierTest, RemoveNoopReshape) { + HloComputation::Builder builder(TestName()); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {2, 3}), "param")); + builder.AddInstruction( + HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {2, 3}), param)); + + HloModule module(TestName()); + HloComputation* computation = module.AddEntryComputation(builder.Build()); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + EXPECT_EQ(1, computation->instruction_count()); +} + +TEST_F(AlgebraicSimplifierTest, RemoveNoopSlice) { + HloComputation::Builder builder(TestName()); + const int64 dim0 = 2; + const int64 dim1 = 3; + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {dim0, dim1}), "param")); + builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {dim0, dim1}), param, /*start_indices=*/{0, 0}, + /*limit_indices=*/{dim0, dim1})); + + HloModule module(TestName()); + HloComputation* computation = module.AddEntryComputation(builder.Build()); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + EXPECT_EQ(1, computation->instruction_count()); +} + +TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { + struct ConvTestOptions { + int in_batch = 10; + int in_height = 2; + int in_width = 2; + int in_channels = 3; + int f_width = 1; + int f_height = 1; + int f_output_channels = 10; + int row_stride = 1; + int row_padding = 0; + int col_stride = 1; + 
int col_padding = 0; + bool input_minor_to_major_layout = false; + bool filter_minor_to_major_layout = false; + bool output_minor_to_major_layout = false; + + const char* dim_order = "NHWC"; // can use chars NHWC in any order. + const char* kernel_dim_order = "HWIO"; // can use chars HWIO in any order. + + ConvTestOptions& Reset() { + *this = ConvTestOptions(); + return *this; + } + }; + + ConvTestOptions options; + + // Builds a convolution from and runs algebraic simplification on + // the computation. Returns a string description of the result of + // simplification. + auto build_and_simplify = [&options, this]() -> string { + HloComputation::Builder b(TestName()); + + Window window; + auto* f_dim_1 = window.add_dimensions(); + f_dim_1->set_size(options.f_height); + f_dim_1->set_stride(options.row_stride); + f_dim_1->set_padding_low(options.row_padding); + f_dim_1->set_padding_high(options.row_padding); + f_dim_1->set_window_dilation(1); + f_dim_1->set_base_dilation(1); + auto* f_dim_2 = window.add_dimensions(); + f_dim_2->set_size(options.f_width); + f_dim_2->set_stride(options.col_stride); + f_dim_2->set_padding_low(options.col_padding); + f_dim_2->set_padding_high(options.col_padding); + f_dim_2->set_window_dilation(1); + f_dim_2->set_base_dilation(1); + + ConvolutionDimensionNumbers dnums; + std::vector in_dims; + int in_channel_idx = -1; + dnums.add_spatial_dimensions(-1); // filled in later + dnums.add_spatial_dimensions(-1); // filled in later + for (int i = 0; i < strlen(options.dim_order); ++i) { + char ch = options.dim_order[i]; + if (ch == 'N') { + dnums.set_batch_dimension(i); + in_dims.push_back(options.in_batch); + } else if (ch == 'H') { + dnums.set_spatial_dimensions(0, i); + in_dims.push_back(options.in_height); + } else if (ch == 'W') { + dnums.set_spatial_dimensions(1, i); + in_dims.push_back(options.in_width); + } else if (ch == 'C') { + dnums.set_feature_dimension(i); + in_dims.push_back(options.in_channels); + in_channel_idx = i; + } + } + + std::vector f_dims; + dnums.add_kernel_spatial_dimensions(-1); // filled in later + dnums.add_kernel_spatial_dimensions(-1); // filled in later + for (int i = 0; i < strlen(options.kernel_dim_order); ++i) { + char ch = options.kernel_dim_order[i]; + if (ch == 'H') { + dnums.set_kernel_spatial_dimensions(0, i); + f_dims.push_back(options.f_height); + } else if (ch == 'W') { + dnums.set_kernel_spatial_dimensions(1, i); + f_dims.push_back(options.f_width); + } else if (ch == 'I') { + dnums.set_kernel_input_feature_dimension(i); + f_dims.push_back(options.in_channels); + } else if (ch == 'O') { + dnums.set_kernel_output_feature_dimension(i); + f_dims.push_back(options.f_output_channels); + } + } + + auto out_dims = in_dims; + out_dims[in_channel_idx] = options.f_output_channels; + + auto make_shape = [](tensorflow::gtl::ArraySlice dims, + bool minor_to_major_layout) { + if (minor_to_major_layout) { + return ShapeUtil::MakeShapeWithLayout(F32, dims, {0, 1, 2, 3}); + } else { + return ShapeUtil::MakeShape(F32, dims); + } + }; + auto in_shape = make_shape(in_dims, options.input_minor_to_major_layout); + auto f_shape = make_shape(f_dims, options.filter_minor_to_major_layout); + auto out_shape = make_shape(out_dims, options.output_minor_to_major_layout); + + HloInstruction* input = + b.AddInstruction(HloInstruction::CreateParameter(0, in_shape, "input")); + HloInstruction* filter = + b.AddInstruction(HloInstruction::CreateParameter(1, f_shape, "filter")); + + b.AddInstruction(HloInstruction::CreateConvolve(out_shape, input, filter, + 
window, dnums)); + + HloModule module(TestName()); + auto* computation = module.AddEntryComputation(b.Build()); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, + bitcasting_callback()); + if (!simplifier.Run(&module).ValueOrDie()) { + return "NO_CHANGE"; + } + auto* root = computation->root_instruction(); + if (root->opcode() == HloOpcode::kBitcast && + root->operand(0)->opcode() == HloOpcode::kDot) { + auto lhs_shape = root->operand(0)->operand(0)->shape(); + auto rhs_shape = root->operand(0)->operand(1)->shape(); + return tensorflow::strings::StrCat( + tensorflow::str_util::Join(lhs_shape.dimensions(), "x"), " DOT ", + tensorflow::str_util::Join(rhs_shape.dimensions(), "x")); + } + return "UNEXPECTED CHANGE"; + }; + + // Default options are the simplest case and succeed. + options.Reset(); + EXPECT_EQ("40x3 DOT 3x10", build_and_simplify()); + + // Swapping dim spatial and batch order works. + options.Reset().dim_order = "NWHC"; + EXPECT_EQ("40x3 DOT 3x10", build_and_simplify()); + options.Reset().dim_order = "WHNC"; + EXPECT_EQ("40x3 DOT 3x10", build_and_simplify()); + // Channel dimension earlier fails. + options.Reset().dim_order = "HWCN"; + EXPECT_EQ("NO_CHANGE", build_and_simplify()); + options.Reset().dim_order = "CHWN"; + EXPECT_EQ("NO_CHANGE", build_and_simplify()); + + // Filtering dims spatial dims can be anywhere, since they are 1x1. + options.Reset().kernel_dim_order = "WHIO"; + EXPECT_EQ("40x3 DOT 3x10", build_and_simplify()); + options.Reset().kernel_dim_order = "IWOH"; + EXPECT_EQ("40x3 DOT 3x10", build_and_simplify()); + options.Reset().kernel_dim_order = "IWHO"; + EXPECT_EQ("40x3 DOT 3x10", build_and_simplify()); + // But moving output channel before input channel fails. + options.Reset().kernel_dim_order = "HWOI"; + EXPECT_EQ("NO_CHANGE", build_and_simplify()); + options.Reset().kernel_dim_order = "WHOI"; + EXPECT_EQ("NO_CHANGE", build_and_simplify()); + options.Reset().kernel_dim_order = "OWIH"; + EXPECT_EQ("NO_CHANGE", build_and_simplify()); + options.Reset().kernel_dim_order = "OWHI"; + EXPECT_EQ("NO_CHANGE", build_and_simplify()); + + // Combine different dim and kernel dim orders. + options.Reset().kernel_dim_order = "IWHO"; + options.dim_order = "WHNC"; + EXPECT_EQ("40x3 DOT 3x10", build_and_simplify()); + + // Test invalid cases from wrong filter size, strides, or padding. + options.Reset().f_width = 2; + EXPECT_EQ("NO_CHANGE", build_and_simplify()); + options.Reset().f_height = 2; + EXPECT_EQ("NO_CHANGE", build_and_simplify()); + options.Reset().row_stride = 2; + EXPECT_EQ("NO_CHANGE", build_and_simplify()); + options.Reset().col_stride = 2; + EXPECT_EQ("NO_CHANGE", build_and_simplify()); + options.Reset().col_padding = 1; + EXPECT_EQ("NO_CHANGE", build_and_simplify()); + options.Reset().row_padding = 1; + EXPECT_EQ("NO_CHANGE", build_and_simplify()); + + // The default dim_order is "NHWC". Col-major layout makes C the most major. + options.Reset().input_minor_to_major_layout = true; + options.output_minor_to_major_layout = true; + EXPECT_EQ("NO_CHANGE", build_and_simplify()); + + // The input and output have different layouts. + options.Reset().input_minor_to_major_layout = true; + EXPECT_EQ("NO_CHANGE", build_and_simplify()); + + // C is most minor, and I is more major than O. 
+ options.Reset().input_minor_to_major_layout = true; + options.filter_minor_to_major_layout = true; + options.output_minor_to_major_layout = true; + options.dim_order = "CHWN"; + options.kernel_dim_order = "OIHW"; + EXPECT_EQ("40x3 DOT 3x10", build_and_simplify()); + + // C is not the most minor dimension. + options.Reset().input_minor_to_major_layout = true; + options.filter_minor_to_major_layout = true; + options.output_minor_to_major_layout = true; + options.dim_order = "HWNC"; + options.kernel_dim_order = "OIHW"; + EXPECT_EQ("NO_CHANGE", build_and_simplify()); + + // I is more minor than O. + options.Reset().input_minor_to_major_layout = true; + options.filter_minor_to_major_layout = true; + options.output_minor_to_major_layout = true; + options.dim_order = "CHWN"; + options.kernel_dim_order = "IOHW"; + EXPECT_EQ("NO_CHANGE", build_and_simplify()); +} + +// Test that max(min(A, x), y) is transformed to clamp(y, A, x) +TEST_F(AlgebraicSimplifierTest, MaxMinToClamp) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "param0")); + HloInstruction* min_value = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction* max_value = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); + HloInstruction* min = builder.AddInstruction(HloInstruction::CreateBinary( + r0f32, HloOpcode::kMinimum, param0, min_value)); + HloInstruction* max = builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kMaximum, min, max_value)); + + HloModule module(TestName()); + auto computation = module.AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root, max); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + root = computation->root_instruction(); + ASSERT_EQ(root->opcode(), HloOpcode::kClamp); + EXPECT_EQ(root->operand(0), max_value); + EXPECT_EQ(root->operand(1), param0); + EXPECT_EQ(root->operand(2), min_value); +} + +// Test that min(max(A, x), y) is transformed to clamp(x, A, y) for scalar +// values. 
+TEST_F(AlgebraicSimplifierTest, MinMaxToClamp) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "param0")); + HloInstruction* min_value = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction* max_value = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); + HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary( + r0f32, HloOpcode::kMaximum, param0, max_value)); + HloInstruction* min = builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kMinimum, max, min_value)); + + HloModule module(TestName()); + auto computation = module.AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root, min); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kClamp); + EXPECT_EQ(root->operand(0), max_value); + EXPECT_EQ(root->operand(1), param0); + EXPECT_EQ(root->operand(2), min_value); +} + +// Test that min(max(A, x), y) is transformed to clamp(x, A, y) for +// broadcasted scalar values. +TEST_F(AlgebraicSimplifierTest, MinMaxWithBroadcastToClamp) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + Shape r1f32 = ShapeUtil::MakeShape(F32, {100}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r1f32, "param0")); + HloInstruction* min_value = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction* max_value = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); + HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary( + r1f32, HloOpcode::kMaximum, param0, max_value)); + HloInstruction* min = builder.AddInstruction( + HloInstruction::CreateBinary(r1f32, HloOpcode::kMinimum, max, min_value)); + + HloModule module(TestName()); + auto computation = module.AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root, min); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kClamp); + EXPECT_EQ(root->operand(0), max_value); + EXPECT_EQ(root->operand(1), param0); + EXPECT_EQ(root->operand(2), min_value); +} + +// Test that min(max(A, non-constant1), non-constant2) is not canonicalized to +// clamp(non-constant1, A, non-constant2) +TEST_F(AlgebraicSimplifierTest, MinMaxNotToClamp) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "param0")); + HloInstruction* min_value = builder.AddInstruction( + HloInstruction::CreateParameter(1, r0f32, "param1")); + HloInstruction* max_value = builder.AddInstruction( + HloInstruction::CreateParameter(2, r0f32, "param2")); + HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary( + r0f32, HloOpcode::kMaximum, param0, max_value)); + HloInstruction* min = builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, 
HloOpcode::kMinimum, max, min_value)); + + HloModule module(TestName()); + auto computation = module.AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + EXPECT_FALSE(simplifier.Run(&module).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, min); +} + +// Test that min(f(max(A, constant1)), constant2) is not transformed to +// clamp(constant1, A, constant2) +TEST_F(AlgebraicSimplifierTest, MinEquationWithMaxNotToClamp) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "param0")); + HloInstruction* min_value = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction* max_value = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); + HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary( + r0f32, HloOpcode::kMaximum, param0, max_value)); + HloInstruction* fmax = builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, max, max_value)); + HloInstruction* min = builder.AddInstruction(HloInstruction::CreateBinary( + r0f32, HloOpcode::kMinimum, fmax, min_value)); + + HloModule module(TestName()); + auto computation = module.AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root, min); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + EXPECT_FALSE(simplifier.Run(&module).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, min); +} + +// Test that slice(broadcast(/*scalar value*/)) simplifies to a single +// broadcast. +TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* scalar_param = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "scalar_param")); + + Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6, 7}); + HloInstruction* broadcast = + builder.AddInstruction(HloInstruction::CreateBroadcast( + broadcast_shape, scalar_param, + AsInt64Slice(broadcast_shape.dimensions()))); + + Shape slice_shape = ShapeUtil::MakeShape(F32, {2, 2, 3, 3}); + HloInstruction* slice = builder.AddInstruction(HloInstruction::CreateSlice( + slice_shape, broadcast, {0, 1, 2, 3}, {2, 3, 5, 6})); + + HloModule module(TestName()); + auto computation = module.AddEntryComputation(builder.Build()); + + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root, slice); + EXPECT_TRUE(ShapeUtil::Equal(root->shape(), slice_shape)); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + + root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); + EXPECT_EQ(scalar_param, root->operand(0)); + EXPECT_TRUE(ShapeUtil::Equal(root->shape(), slice_shape)); +} + +// Test that reshape(transpose(broadcast(/*scalar value*/))) simplifies to a +// single broadcast. 
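+// A broadcast of a scalar has the same value at every position, so any
+// transpose or reshape of it is itself just a broadcast of that scalar.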
+TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { + HloComputation::Builder builder(TestName()); + HloInstruction* forty_two = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + + Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6}); + HloInstruction* broadcast = + builder.AddInstruction(HloInstruction::CreateBroadcast( + broadcast_shape, forty_two, + AsInt64Slice(broadcast_shape.dimensions()))); + + HloInstruction* transpose = + builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {6, 5, 4}), broadcast, {2, 1, 0})); + + Shape reshape_shape = ShapeUtil::MakeShape(F32, {30, 1, 4}); + HloInstruction* reshape = builder.AddInstruction( + HloInstruction::CreateReshape(reshape_shape, transpose)); + + HloModule module(TestName()); + auto computation = module.AddEntryComputation(builder.Build()); + + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root, reshape); + EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reshape_shape)); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + + root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); + EXPECT_EQ(forty_two, root->operand(0)); + EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reshape_shape)); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc new file mode 100644 index 0000000000..a123213401 --- /dev/null +++ b/tensorflow/compiler/xla/service/allocation_tracker.cc @@ -0,0 +1,215 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/allocation_tracker.h" + +#include + +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/transfer_manager.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace se = ::perftools::gputools; + +namespace xla { + +AllocationTracker::AllocationTracker() : next_handle_(1) {} + +GlobalDataHandle AllocationTracker::Register(Backend* backend, + int device_ordinal, + se::DeviceMemoryBase device_memory, + const Shape& shape, + const string& tag) { + tensorflow::mutex_lock lock(allocation_mutex_); + VLOG(2) << "Register"; + return RegisterInternal(backend, device_ordinal, device_memory, shape, tag, + /*initial_ref_count=*/1); +} + +GlobalDataHandle AllocationTracker::RegisterInternal( + Backend* backend, int device_ordinal, se::DeviceMemoryBase device_memory, + const Shape& shape, const string& tag, int initial_ref_count) { + VLOG(2) << "RegisterInternal(" + << "tag: \"" << tag << "\" " + << "device_ordinal: " << device_ordinal << " " + << "device_memory: " << device_memory.opaque() << " " + << "shape: " << shape.ShortDebugString() << ")"; + TF_CHECK_OK(ShapeUtil::ValidateShape(shape)); + + int64 handle; + HandleMap& handle_map = GetOrCreateOpaqueToHandleMap(device_ordinal); + auto handle_it = handle_map.find(device_memory.opaque()); + if (handle_it != handle_map.end()) { + handle = handle_it->second; + auto& allocation = FindOrDie(handle_to_allocation_, handle); + int ref_count = allocation->ref_count(); + CHECK_GT(ref_count, 0); + VLOG(2) << "ref_count: " << ref_count << " -> " << ref_count + 1; + allocation->increment_ref_count(); + } else { + handle = next_handle_++; + VLOG(2) << "ref_count: " << initial_ref_count; + InsertOrDie(&handle_map, device_memory.opaque(), handle); + auto inserted = handle_to_allocation_.emplace( + handle, MakeUnique(backend, device_ordinal, device_memory, + shape, tag, initial_ref_count)); + CHECK(inserted.second); + } + + GlobalDataHandle result; + result.set_handle(handle); + VLOG(2) << "handle: " << handle; + + return result; +} + +tensorflow::Status AllocationTracker::Unregister(const GlobalDataHandle& data) { + tensorflow::mutex_lock lock(allocation_mutex_); + TF_ASSIGN_OR_RETURN(Allocation * allocation, ResolveInternal(data)); + std::set deallocated_buffers; + TF_RETURN_IF_ERROR( + DeallocateShape(allocation->backend(), allocation->device_ordinal(), + allocation->mutable_device_memory(), allocation->shape(), + &deallocated_buffers)); + return tensorflow::Status::OK(); +} + +tensorflow::Status AllocationTracker::DeallocateShape( + Backend* backend, int device_ordinal, se::DeviceMemoryBase* device_memory, + const Shape& shape, std::set* deallocated_buffers) { + VLOG(2) << "DeallocateShape(" + << "shape: \"" << shape.ShortDebugString() << "\" " + << "device_memory: " << device_memory->opaque() << ")"; + if (ContainsKey(*deallocated_buffers, device_memory->opaque())) { + // Buffer has already been deallocated. Nothing to do. 
+    VLOG(2) << "already deallocated";
+    return tensorflow::Status::OK();
+  }
+
+  // Add buffer to deallocated set so we do not try to deallocate it again
+  // if it is encountered again while traversing a tuple.
+  deallocated_buffers->insert(device_memory->opaque());
+
+  HandleMap& handle_map = GetOrCreateOpaqueToHandleMap(device_ordinal);
+  auto handle_it = handle_map.find(device_memory->opaque());
+  if (handle_it != handle_map.end()) {
+    int64 handle = handle_it->second;
+    auto& allocation = FindOrDie(handle_to_allocation_, handle);
+    int ref_count = allocation->ref_count();
+    VLOG(2) << "ref_count: " << ref_count << " -> " << ref_count - 1;
+    allocation->decrement_ref_count();
+    if (allocation->ref_count() > 0) {
+      // Buffer is referred to by another allocation. Don't deallocate it.
+      return tensorflow::Status::OK();
+    }
+    handle_map.erase(device_memory->opaque());
+  }
+
+  if (ShapeUtil::IsTuple(shape)) {
+    // Traverse into tuple recursively deallocating buffers.
+    TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
+                        backend->stream_executor(device_ordinal));
+    TF_ASSIGN_OR_RETURN(std::vector<se::DeviceMemoryBase> elements,
+                        backend->transfer_manager()->ShallowCopyTupleFromDevice(
+                            executor, *device_memory, shape));
+
+    TF_RET_CHECK(ShapeUtil::TupleElementCount(shape) == elements.size())
+        << "tuple has unexpected number of elements: " << elements.size()
+        << " != " << ShapeUtil::TupleElementCount(shape);
+    for (int i = 0; i < elements.size(); ++i) {
+      VLOG(2) << "recursing onto the tuple elements";
+      TF_RETURN_IF_ERROR(DeallocateShape(backend, device_ordinal, &elements[i],
+                                         shape.tuple_shapes(i),
+                                         deallocated_buffers));
+    }
+  }
+
+  return backend->memory_allocator()->Deallocate(device_ordinal, device_memory);
+}
+
+StatusOr<std::vector<GlobalDataHandle>> AllocationTracker::DeconstructTuple(
+    const GlobalDataHandle& data) {
+  tensorflow::mutex_lock lock(allocation_mutex_);
+  TF_ASSIGN_OR_RETURN(Allocation * allocation, ResolveInternal(data));
+
+  if (!ShapeUtil::IsTuple(allocation->shape())) {
+    return InvalidArgument("global data handle %lld is not a tuple",
+                           data.handle());
+  }
+
+  if (ShapeUtil::IsNestedTuple(allocation->shape())) {
+    return Unimplemented("deconstructing nested tuples not yet supported");
+  }
+
+  TF_ASSIGN_OR_RETURN(
+      se::StreamExecutor * executor,
+      allocation->backend()->stream_executor(allocation->device_ordinal()));
+  TF_ASSIGN_OR_RETURN(
+      std::vector<se::DeviceMemoryBase> element_bases,
+      allocation->backend()->transfer_manager()->ShallowCopyTupleFromDevice(
+          executor, allocation->device_memory(), allocation->shape()));
+
+  std::vector<GlobalDataHandle> element_handles;
+  for (int i = 0; i < element_bases.size(); ++i) {
+    element_handles.push_back(RegisterInternal(
+        allocation->backend(), allocation->device_ordinal(), element_bases[i],
+        ShapeUtil::GetSubshape(allocation->shape(), {i}),
+        tensorflow::strings::StrCat(allocation->tag(), ".element_", i),
+        /*initial_ref_count=*/2));
+  }
+  return std::move(element_handles);
+}
+
+StatusOr<const Allocation*> AllocationTracker::Resolve(
+    const GlobalDataHandle& data) {
+  tensorflow::mutex_lock lock(allocation_mutex_);
+  return AllocationTracker::ResolveInternal(data);
+}
+
+StatusOr<Allocation*> AllocationTracker::ResolveInternal(
+    const GlobalDataHandle& data) {
+  VLOG(2) << "resolve:" << data.handle();
+  auto it = handle_to_allocation_.find(data.handle());
+  if (it == handle_to_allocation_.end()) {
+    return NotFound("no allocation record for global data handle: %lld",
+                    data.handle());
+  }
+  Allocation* allocation = it->second.get();
+
+  if (allocation->is_deallocated()) {
+    return InvalidArgument("global data handle %lld was
previously deallocated", + data.handle()); + } + + return allocation; +} + +AllocationTracker::HandleMap& AllocationTracker::GetOrCreateOpaqueToHandleMap( + int device_ordinal) { + if (opaque_to_handle_.size() <= device_ordinal) { + opaque_to_handle_.resize(device_ordinal + 1); + } + return opaque_to_handle_[device_ordinal]; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/allocation_tracker.h b/tensorflow/compiler/xla/service/allocation_tracker.h new file mode 100644 index 0000000000..e007680016 --- /dev/null +++ b/tensorflow/compiler/xla/service/allocation_tracker.h @@ -0,0 +1,178 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_ALLOCATION_TRACKER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_ALLOCATION_TRACKER_H_ + +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/service/backend.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// A global allocation in device space, tracked by the XLA service. +class Allocation { + public: + Allocation(Backend* backend, int device_ordinal, + perftools::gputools::DeviceMemoryBase device_memory, + const Shape& shape, const string& tag, int initial_ref_count) + : backend_(backend), + device_ordinal_(device_ordinal), + device_memory_(device_memory), + shape_(shape), + tag_(tag), + ref_count_(initial_ref_count) {} + + Backend* backend() const { return backend_; } + int device_ordinal() const { return device_ordinal_; } + perftools::gputools::DeviceMemoryBase device_memory() const { + return device_memory_; + } + const Shape& shape() const { return shape_; } + const string& tag() const { return tag_; } + + bool is_deallocated() const { + CHECK_GE(ref_count_, 0); + return ref_count_ == 0; + } + int ref_count() const { + CHECK_GE(ref_count_, 0); + return ref_count_; + } + void increment_ref_count() { + CHECK_GT(ref_count_, 0); + CHECK_LT(ref_count_, INT_MAX); + ++ref_count_; + } + void decrement_ref_count() { + CHECK_GT(ref_count_, 0); + --ref_count_; + } + perftools::gputools::DeviceMemoryBase* mutable_device_memory() { + return &device_memory_; + } + + private: + // The backend that the memory is allocated on. + Backend* backend_; + + // The device that the memory is allocated on. + int device_ordinal_; + + // The pointer to this allocation. + perftools::gputools::DeviceMemoryBase device_memory_; + + // The shape of this allocation. + Shape shape_; + + // An informal description of this allocation shown in tools. 
+  string tag_;
+
+  // This is the number of Allocation objects which refer to this memory
+  // allocation.
+  int ref_count_;
+
+  // Return a string representation of this allocation for debugging or logging
+  // purposes.
+  string ToString() const;
+};
+
+// Tracks allocations for the XLA service; allocations can be registered
+// with shape/device/tag and resolved from a handle for later use.
+class AllocationTracker {
+ public:
+  AllocationTracker();
+
+  // Registers device memory with a given shape, device identifier, and tag,
+  // and returns a corresponding handle that can be used for talking to XLA
+  // clients.
+  GlobalDataHandle Register(Backend* backend, int device_ordinal,
+                            perftools::gputools::DeviceMemoryBase device_memory,
+                            const Shape& shape, const string& tag);
+
+  // Unregister the allocation for the given data handle.
+  tensorflow::Status Unregister(const GlobalDataHandle& data);
+
+  // Returns a vector of global data handles that point to the tuple elements.
+  StatusOr<std::vector<GlobalDataHandle>> DeconstructTuple(
+      const GlobalDataHandle& data);
+
+  // Resolve a handle from an XLA client to an allocation, or provide an
+  // error status to say whether it was not found (or found, but found
+  // deallocated).
+  StatusOr<const Allocation*> Resolve(const GlobalDataHandle& data);
+
+ private:
+  // Internal helper which resolves the given GlobalDataHandle to an
+  // Allocation.
+  StatusOr<Allocation*> ResolveInternal(const GlobalDataHandle& data)
+      EXCLUSIVE_LOCKS_REQUIRED(allocation_mutex_);
+
+  GlobalDataHandle RegisterInternal(
+      Backend* backend, int device_ordinal,
+      perftools::gputools::DeviceMemoryBase device_memory, const Shape& shape,
+      const string& tag, int initial_ref_count)
+      EXCLUSIVE_LOCKS_REQUIRED(allocation_mutex_);
+
+  // Helper function which deallocates the memory buffer containing the given
+  // shape referred to by device_memory. Tuples are traversed recursively
+  // deallocating all nested buffers. The parameter deallocated_buffers
+  // contains the set of buffers deallocated so far, stored as opaque values
+  // (void*) from DeviceMemoryBase. Keeping track of deallocated buffers
+  // prevents double-freeing of buffers which may be referred to more than
+  // once in a nested tuple.
+  tensorflow::Status DeallocateShape(
+      Backend* backend, int device_ordinal,
+      perftools::gputools::DeviceMemoryBase* device_memory, const Shape& shape,
+      std::set<void*>* deallocated_buffers)
+      EXCLUSIVE_LOCKS_REQUIRED(allocation_mutex_);
+
+  // Returns the opaque_to_handle_ map for the given device_ordinal, creating
+  // a new map if there is not one for the device_ordinal.
+  using HandleMap = std::map<void*, int64>;
+  HandleMap& GetOrCreateOpaqueToHandleMap(int device_ordinal)
+      EXCLUSIVE_LOCKS_REQUIRED(allocation_mutex_);
+
+  tensorflow::mutex allocation_mutex_;  // Guards the allocation mapping.
+
+  // The next handle to assign to an allocation, guarded by the same mutex as
+  // the mapping as they'll be mutated at the same time.
+  int64 next_handle_ GUARDED_BY(allocation_mutex_);
+
+  // A map from DeviceMemoryBase to handle for each device_ordinal.
+  std::vector<HandleMap> opaque_to_handle_ GUARDED_BY(allocation_mutex_);
+
+  // Mapping from GlobalDataHandle handle to the corresponding registered
+  // Allocation object.
+ std::map> handle_to_allocation_ + GUARDED_BY(allocation_mutex_); + + TF_DISALLOW_COPY_AND_ASSIGN(AllocationTracker); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_ALLOCATION_TRACKER_H_ diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc new file mode 100644 index 0000000000..6e76c98c9f --- /dev/null +++ b/tensorflow/compiler/xla/service/backend.cc @@ -0,0 +1,237 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/backend.h" + +#include +#include +#include + +#define EIGEN_USE_THREADS + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/compiler/xla/legacy_flags/backend_flags.h" +#include "tensorflow/compiler/xla/service/compiler.h" +#include "tensorflow/compiler/xla/service/platform_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/common_runtime/eigen_thread_pool.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/cpu_info.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace se = ::perftools::gputools; + +namespace xla { + +// Define this in .cc file to avoid having to include eigen or forward declare +// these types in the header. 
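+// The wrapper owns a thread pool sized to the number of schedulable CPUs and
+// exposes it to Eigen as a ThreadPoolDevice for intra-op parallelism.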
+struct Backend::EigenThreadPoolWrapper {
+  explicit EigenThreadPoolWrapper()
+      : pool(new tensorflow::thread::ThreadPool(
+            tensorflow::Env::Default(), "XLAEigen",
+            tensorflow::port::NumSchedulableCPUs())),
+        wrapper(new tensorflow::EigenThreadPoolWrapper(pool.get())),
+        device(new Eigen::ThreadPoolDevice(wrapper.get(),
+                                           wrapper->NumThreads())) {}
+
+  std::unique_ptr<tensorflow::thread::ThreadPool> pool;
+  std::unique_ptr<tensorflow::EigenThreadPoolWrapper> wrapper;
+  std::unique_ptr<Eigen::ThreadPoolDevice> device;
+};
+
+/* static */ StatusOr<std::unique_ptr<Backend>> Backend::CreateBackend(
+    perftools::gputools::Platform* platform, int64 replica_count) {
+  if (replica_count == -1) {
+    legacy_flags::BackendFlags* flags = legacy_flags::GetBackendFlags();
+    replica_count = flags->xla_replicas;
+  }
+  TF_ASSIGN_OR_RETURN(auto compiler, Compiler::GetForPlatform(platform));
+  TF_ASSIGN_OR_RETURN(auto stream_executors,
+                      PlatformUtil::GetStreamExecutors(platform));
+  TF_ASSIGN_OR_RETURN(auto transfer_manager,
+                      TransferManager::GetForPlatform(platform));
+  std::unique_ptr<Backend> backend(new Backend(
+      replica_count, platform, compiler, stream_executors, transfer_manager));
+  TF_RETURN_IF_ERROR(backend->PoolStreams(kInitialStreamsToPool,
+                                          backend->default_stream_executor()));
+  return std::move(backend);
+}
+
+/* static */ StatusOr<std::unique_ptr<Backend>>
+Backend::CreateDefaultBackend() {
+  TF_ASSIGN_OR_RETURN(se::Platform * platform,
+                      PlatformUtil::GetDefaultPlatform());
+  return CreateBackend(platform);
+}
+
+tensorflow::Status Backend::PoolStreams(int n, se::StreamExecutor* executor) {
+  std::vector<std::unique_ptr<se::Stream>> primed;
+  for (int i = 0; i < n; ++i) {
+    TF_ASSIGN_OR_RETURN(auto stream, AcquireStream(executor));
+    primed.emplace_back(std::move(stream));
+  }
+  for (int i = 0; i < n; ++i) {
+    ReleaseStream(std::move(primed.back()));
+    primed.pop_back();
+  }
+  return tensorflow::Status::OK();
+}
+
+StatusOr<std::unique_ptr<perftools::gputools::Stream>> Backend::AcquireStream(
+    perftools::gputools::StreamExecutor* executor) {
+  tensorflow::mutex_lock lock(mutex_);
+  auto& cached_streams = cached_streams_[executor];
+  if (!cached_streams.empty()) {
+    auto result = std::move(cached_streams.back());
+    cached_streams.pop_back();
+    return std::move(result);
+  }
+
+  auto stream = MakeUnique<se::Stream>(executor);
+  if (!stream->Init().ok()) {
+    return InternalError("failed to initialize stream");
+  }
+  return std::move(stream);
+}
+
+void Backend::ReleaseStream(
+    std::unique_ptr<perftools::gputools::Stream> stream) {
+  tensorflow::mutex_lock lock(mutex_);
+  auto& streams = cached_streams_[stream->parent()];
+  streams.emplace_back(std::move(stream));
+}
+
+Backend::Backend(
+    int64 replica_count, perftools::gputools::Platform* platform,
+    Compiler* compiler,
+    tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
+        stream_executors,
+    TransferManager* transfer_manager)
+    : platform_(platform),
+      compiler_(compiler),
+      transfer_manager_(transfer_manager),
+      replica_count_(replica_count) {
+  // The given set of stream executors may include invalid executors.
+  for (se::StreamExecutor* exec : stream_executors) {
+    if (exec != nullptr) {
+      stream_executors_.push_back(exec);
+    }
+  }
+  CHECK_GE(replica_count, 1) << "Must request at least 1 replica.";
+
+  // Create a memory allocator for the valid stream executors.
+  memory_allocator_ =
+      MakeUnique<StreamExecutorMemoryAllocator>(platform, stream_executors);
+
+  // First check that there are some non-null stream executors to avoid issuing
+  // an error mentioning replicas in the common case of requesting just 1
+  // replica, which means no replication.
+  CHECK(!stream_executors_.empty())
+      << "Service found no devices for backend " << platform_->Name() << '.';
+  CHECK_GE(stream_executors_.size(), replica_count)
+      << "Requested more replicas than there are devices for backend "
+      << platform_->Name() << '.';
+
+  if (platform->id() == se::host::kHostPlatformId) {
+    inter_op_thread_pool_.reset(new tensorflow::thread::ThreadPool(
+        tensorflow::Env::Default(), "xla_inter_op",
+        tensorflow::port::NumSchedulableCPUs()));
+    intra_op_thread_pool_wrapper_.reset(new EigenThreadPoolWrapper());
+  }
+}
+
+Backend::~Backend() {}
+
+int Backend::default_device_ordinal() const {
+  return default_stream_executor()->device_ordinal();
+}
+
+StatusOr<std::vector<perftools::gputools::StreamExecutor*>> Backend::Replicas(
+    int device_ordinal) const {
+  if (stream_executors_[device_ordinal] == nullptr) {
+    return InvalidArgument("device %s not supported by XLA service",
+                           device_name(device_ordinal).c_str());
+  }
+
+  // Find replica_count_ stream executors starting from the given device
+  // ordinal.
+  std::vector<perftools::gputools::StreamExecutor*> replicas;
+  for (se::StreamExecutor* exec : stream_executors_) {
+    CHECK(exec != nullptr);
+    if (exec->device_ordinal() >= device_ordinal) {
+      replicas.push_back(exec);
+      if (replicas.size() >= replica_count_) {
+        return replicas;
+      }
+    }
+  }
+
+  return InvalidArgument(
+      "Not enough devices for replicas for the device ordinal %d",
+      device_ordinal);
+}
+
+std::vector<perftools::gputools::StreamExecutor*> Backend::Replicas() const {
+  CHECK_GE(stream_executors_.size(), replica_count_);
+  return Replicas(default_device_ordinal()).ValueOrDie();
+}
+
+tensorflow::thread::ThreadPool* Backend::inter_op_thread_pool() const {
+  return inter_op_thread_pool_.get();
+}
+
+const Eigen::ThreadPoolDevice* Backend::eigen_intra_op_thread_pool_device()
+    const {
+  if (intra_op_thread_pool_wrapper_ == nullptr) return nullptr;
+  return intra_op_thread_pool_wrapper_->device.get();
+}
+
+StatusOr<perftools::gputools::StreamExecutor*> Backend::stream_executor(
+    int device_ordinal) const {
+  if (device_ordinal < 0 ||
+      device_ordinal > stream_executors_.back()->device_ordinal()) {
+    return InvalidArgument(
+        "Invalid device ordinal value (%d). Valid range is [0, %d].",
+        device_ordinal, stream_executors_.back()->device_ordinal());
+  }
+  for (auto* executor : stream_executors_) {
+    if (executor->device_ordinal() == device_ordinal) {
+      return executor;
+    }
+  }
+  return InvalidArgument("device %s not supported by XLA service",
+                         device_name(device_ordinal).c_str());
+}
+
+StatusOr<bool> Backend::devices_equivalent(int device_ordinal_a,
+                                           int device_ordinal_b) {
+  // Use the name from device description to determine equivalence. This is a
+  // bit crude but works for GPUs which is the important case where we compile
+  // an executable for one GPU and want to know if it will run (well) on
+  // another.
+  TF_ASSIGN_OR_RETURN(perftools::gputools::StreamExecutor * executor_a,
+                      stream_executor(device_ordinal_a));
+  TF_ASSIGN_OR_RETURN(perftools::gputools::StreamExecutor * executor_b,
+                      stream_executor(device_ordinal_b));
+  return (executor_a->GetDeviceDescription().name() ==
+          executor_b->GetDeviceDescription().name());
+}
+
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/backend.h b/tensorflow/compiler/xla/service/backend.h
new file mode 100644
index 0000000000..17c53d299e
--- /dev/null
+++ b/tensorflow/compiler/xla/service/backend.h
@@ -0,0 +1,191 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BACKEND_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_BACKEND_H_ + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/service/compiler.h" +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/transfer_manager.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/thread_annotations.h" + +namespace Eigen { +class ThreadPoolDevice; +} + +namespace xla { + +// Class which encapsulates an XLA backend. It includes everything necessary +// to compile and execute computations on a particular platform. +// +// It also offers a pooling API for creation/use of initialized streams: +// +// std::unique_ptr stream = +// backend->AcquireStream().ConsumeValueOrDie(); +// // ... use stream ... +// backend->ReleaseStream(std::move(stream)); +class Backend { + public: + // The number of streams we create for the pool at initialization time. + static constexpr int kInitialStreamsToPool = 8; + + // Creates a new backend for the given platform with the given number of + // replicas. A value of -1 means to use the flag value. + static StatusOr> CreateBackend( + perftools::gputools::Platform* platform, int64 replica_count = -1); + + // Creates a backend for the default platform. The default platform is defined + // in PlatformUtil. + static StatusOr> CreateDefaultBackend(); + + ~Backend(); + + // Accessors for the various objects. + perftools::gputools::Platform* platform() const { return platform_; } + Compiler* compiler() const { return compiler_; } + DeviceMemoryAllocator* memory_allocator() const { + return memory_allocator_.get(); + } + TransferManager* transfer_manager() const { return transfer_manager_; } + + // Returns the number of devices of the platform type which are visible. Not + // all of these devices may be usable by XLA. + int device_count() const { return stream_executors_.size(); } + + // Returns the device ordinal number of the default device. + int default_device_ordinal() const; + + // Returns stream executors of all supported devices for this backend. The + // executors are ordered by the device ordinal. + const std::vector& stream_executors() + const { + return stream_executors_; + } + + // Returns the replicas for the default stream executor. + // + // When the number of replicas is R, the first R stream executors are assigned + // to the replicas of the default stream executor. + std::vector Replicas() const; + + // Returns the replicas for the given device_ordinal. The given device ordinal + // is considered to be the first device ordinal among the replicas. 
Returns an + // error status if the stream executor for the given given device ordinal does + // not exist or if there are not enough stream executors for the replicas. + StatusOr> Replicas( + int device_ordinal) const; + + // Return the stream executor for the given device ordinal. + StatusOr stream_executor( + int device_ordinal) const; + + // Return the stream executor for the default device ordinal. + perftools::gputools::StreamExecutor* default_stream_executor() const { + CHECK(!stream_executors_.empty()); + return stream_executors_[0]; + } + + // Primes the internal pool of streams for AcquireStream/ReleaseStream with n + // initialized stream instances. + tensorflow::Status PoolStreams(int n, + perftools::gputools::StreamExecutor* executor); + + // Acquires a stream for use by the caller, either by grabbing it from an + // internal pool, or by constructing/initializating it, and returns the result + // to the caller. + // + // TODO(b/32989582): Return std::unique_ptr with custom deleter. + StatusOr> AcquireStream( + perftools::gputools::StreamExecutor* executor); + + // Releases a stream from the caller to the internal pool, for use with the + // paired AcquireStream above. + void ReleaseStream(std::unique_ptr stream); + + // Returns whether the given device ordinal of the backend is supported. + bool device_ordinal_supported(int device_ordinal) const { + return (device_ordinal >= 0 && device_ordinal < device_count() && + stream_executors_[device_ordinal] != nullptr); + } + + // Return a string identifier for the given device, eg: "GPU:3". + string device_name(int device_ordinal) const { + return tensorflow::strings::StrCat(platform_->Name(), ":", device_ordinal); + } + + // Returns true if the devices with the given ordinals are equivalent from + // XLA's perspective. That is, an executable compiled for one device would + // be equivalent to an executable compiled for the other. + StatusOr devices_equivalent(int device_ordinal_a, int device_ordinal_b); + + // For the host platform, returns the threadpool to use when scheduling + // parallel operators. For other platforms, returns NULL. + tensorflow::thread::ThreadPool* inter_op_thread_pool() const; + + // For the host platform, returns the configured eigen threadpool device to be + // used for scheduling work. For other platforms, returns NULL. + const Eigen::ThreadPoolDevice* eigen_intra_op_thread_pool_device() const; + + private: + struct EigenThreadPoolWrapper; + Backend(int64 replica_count, perftools::gputools::Platform* platform, + Compiler* compiler, + tensorflow::gtl::ArraySlice + stream_executors, + TransferManager* transfer_manager); + Backend(const Backend&) = delete; + Backend& operator=(const Backend&) = delete; + + perftools::gputools::Platform* platform_; + Compiler* compiler_; + TransferManager* transfer_manager_; + int64 replica_count_ = -1; + + // Vector of stream executors. stream_executors_[0] is the default executor. + std::vector stream_executors_; + + // Guards the mutable state in the backend object. + tensorflow::mutex mutex_; + + // Mapping from stream executor to cached streams, used by + // AcquireStream/ReleaseStream above. + std::map>> + cached_streams_ GUARDED_BY(mutex_); + + // The default memory allocator to use. + std::unique_ptr memory_allocator_; + + // For the CPU backend, a threadpool for scheduling parallel operators. + std::unique_ptr inter_op_thread_pool_; + + // For the CPU backend, an Eigen threadpool device for use by Eigen code. 
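+  // A rough usage sketch for CPU-backend code ("out" and "in" are assumed to
+  // be Eigen tensor expressions of matching shape; this example is
+  // illustrative, not taken from the patch):
+  //
+  //   out.device(*backend->eigen_intra_op_thread_pool_device()) = in + in;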
+ std::unique_ptr intra_op_thread_pool_wrapper_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BACKEND_H_ diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc new file mode 100644 index 0000000000..1616a1363d --- /dev/null +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -0,0 +1,777 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Defines the data returned by the XLA buffer assignment packages. + +#include "tensorflow/compiler/xla/service/buffer_assignment.h" + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/legacy_flags/buffer_assignment_flags.h" +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/stringprintf.h" + +namespace xla { + +void BufferAllocation::AddAssignment(const LogicalBuffer& buffer) { + DCHECK(std::find(assigned_buffers_.begin(), assigned_buffers_.end(), + &buffer) == assigned_buffers_.end()) + << "LogicalBuffer " << buffer.ToString() + << " already assigned to allocation " << index(); + assigned_buffers_.push_back(&buffer); +} + +string BufferAllocation::ToString() const { + string output; + tensorflow::strings::StrAppend( + &output, tensorflow::strings::Printf("allocation %lld: %p, size %lld", + index_, this, size())); + if (is_entry_computation_parameter()) { + tensorflow::strings::StrAppend(&output, ", parameter ", parameter_number()); + } + if (is_thread_local()) { + tensorflow::strings::StrAppend(&output, ", thread-local"); + } + tensorflow::strings::StrAppend(&output, ":\n"); + for (const auto& buffer : assigned_buffers()) { + tensorflow::strings::StrAppend( + &output, + tensorflow::strings::Printf( + " %s::%s : %s\n", buffer->instruction()->parent()->name().c_str(), + buffer->ToString().c_str(), + ShapeUtil::HumanStringWithLayout(buffer->shape()).c_str())); + } + return output; +} + +std::ostream& operator<<(std::ostream& out, const BufferAllocation& buffer) { + out << buffer.ToString(); + return out; +} + +const PointsToSet& BufferAssignment::GetPointsToSet( + const HloInstruction* instruction) const { + return points_to_analysis().GetPointsToSet(instruction); +} + +bool BufferAssignment::HasAllocation(const LogicalBuffer& buffer) const { + TF_CHECK_OK(points_to_analysis().VerifyBuffer(buffer)); + return 
allocation_index_for_buffer_.count(&buffer) > 0; +} + +const BufferAllocation& BufferAssignment::GetAssignedAllocation( + const LogicalBuffer& buffer) const { + CHECK(HasAllocation(buffer)); + return GetAllocation(allocation_index_for_buffer_.at(&buffer)); +} + +BufferAllocation* BufferAssignment::GetMutableAssignedAllocation( + const LogicalBuffer& buffer) { + return const_cast(&GetAssignedAllocation(buffer)); +} + +std::set BufferAssignment::GetAllocations( + const HloInstruction* instruction, const ShapeIndex& index) const { + std::set allocations; + for (const LogicalBuffer* buffer : GetSourceBuffers(instruction, index)) { + if (allocation_index_for_buffer_.count(buffer) > 0) { + allocations.insert( + GetAllocation(allocation_index_for_buffer_.at(buffer))); + } + } + return allocations; +} + +const BufferAllocation& BufferAssignment::GetAllocation( + BufferAllocation::Index index) const { + CHECK(index >= 0 && index < allocations_.size()) + << "Allocation index " << index << "is out of range."; + return allocations_[index]; +} + +BufferAllocation* BufferAssignment::GetMutableAllocation( + BufferAllocation::Index index) { + return const_cast(&GetAllocation(index)); +} + +bool BufferAssignment::HasTopLevelAllocation( + const HloInstruction* instruction) const { + for (const LogicalBuffer* buffer : + GetPointsToSet(instruction).element(/*index=*/{})) { + if (allocation_index_for_buffer_.count(buffer) > 0) { + return true; + } + } + return false; +} + +StatusOr BufferAssignment::GetUniqueAllocation( + const HloInstruction* instruction, const ShapeIndex& index) const { + const BufferAllocation* allocation = nullptr; + for (const LogicalBuffer* buffer : + GetPointsToSet(instruction).element(index)) { + if (HasAllocation(*buffer)) { + if (allocation != nullptr && + *allocation != GetAssignedAllocation(*buffer)) { + return FailedPrecondition( + "LogicalBuffer allocation for instruction %s at index {%s} cannot " + "be determined at compile-time.", + instruction->name().c_str(), + tensorflow::str_util::Join(index, ",").c_str()); + } + allocation = &GetAssignedAllocation(*buffer); + } + } + if (allocation == nullptr) { + return FailedPrecondition( + "instruction %s has no buffer allocation at index {%s}", + instruction->name().c_str(), + tensorflow::str_util::Join(index, ",").c_str()); + } + return allocation; +} + +StatusOr BufferAssignment::GetUniqueTopLevelAllocation( + const HloInstruction* instruction) const { + return GetUniqueAllocation(instruction, /*index=*/{}); +} + +StatusOr +BufferAssignment::GetUniqueTopLevelOutputAllocation() const { + return GetUniqueTopLevelAllocation( + module_->entry_computation()->root_instruction()); +} + +BufferAllocation* BufferAssignment::NewAllocation(const LogicalBuffer& buffer, + int64 size, + bool is_thread_local) { + BufferAllocation::Index index = allocations_.size(); + allocations_.emplace_back(index, size, is_thread_local); + BufferAllocation* allocation = &allocations_.back(); + AddAssignment(buffer, allocation); + allocation_index_for_buffer_[&buffer] = index; + return allocation; +} + +// Adds an instruction to the set assigned to the given buffer. 
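+// Concretely, the buffer is appended to the allocation's assigned_buffers()
+// list and allocation_index_for_buffer_ is updated so the buffer maps back to
+// its allocation.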
+void BufferAssignment::AddAssignment(const LogicalBuffer& buffer, + BufferAllocation* allocation) { + CHECK_EQ(0, allocation_index_for_buffer_.count(&buffer)) + << "LogicalBuffer " << buffer << " already has an allocation."; + TF_CHECK_OK(points_to_analysis().VerifyBuffer(buffer)); + + allocation->AddAssignment(buffer); + allocation_index_for_buffer_[&buffer] = allocation->index(); +} + +string BufferAssignment::ToString() const { + string output; + tensorflow::strings::StrAppend(&output, "BufferAssignment:\n"); + for (auto& allocation : allocations_) { + tensorflow::strings::StrAppend(&output, allocation.ToString()); + } + return output; +} + +namespace { + +// Walk the call graph of the HLO module and place each computation into either +// thread_local_computations or global_computations depending upon whether the +// computation requires thread-local allocations or global allocations. The +// elements in thread_local_computations and global_computations are in post +// order (if computation A has an instruction which calls computation B, then A +// will appear after B in the vector). +tensorflow::Status GatherComputationsByAllocationType( + const HloModule* module, + std::vector* thread_local_computations, + std::vector* global_computations) { + // Create a worklist of computations paired with whether the allocation must + // be thread-local. + std::deque> worklist; + worklist.push_back(std::make_pair(module->entry_computation(), + /*is_thread_local*/ false)); + + // Sets for quickly checking membership. Computations are returned in vectors + // for stable iteration. + std::unordered_set thread_local_set; + std::unordered_set global_set; + + while (!worklist.empty()) { + auto worklist_front = worklist.front(); + worklist.pop_front(); + HloComputation* computation = worklist_front.first; + bool is_thread_local = worklist_front.second; + bool in_thread_local_set = thread_local_set.count(computation) > 0; + bool in_global_set = global_set.count(computation) > 0; + + // If the computation has already been added to the respective set, then + // nothing to do. + if ((is_thread_local && in_thread_local_set) || + (!is_thread_local && in_global_set)) { + continue; + } + + // If the computation has already been added to the other set this is an + // error condition because the global call to the computation (eg, + // while/call) may return a reference to one of the thread-local buffers to + // the calling computation which will become a dangling reference when the + // thread-local is deallocated with the call return. + if ((is_thread_local && in_global_set) || + (!is_thread_local && in_thread_local_set)) { + return InvalidArgument( + "computation %s has conflicting allocation requirements (global " + "and thread-local)", + computation->name().c_str()); + } + + if (is_thread_local) { + thread_local_set.insert(computation); + } else { + global_set.insert(computation); + } + + for (auto& instruction : computation->instructions()) { + for (auto* subcomputation : instruction->MakeCalledComputationsSet()) { + switch (instruction->opcode()) { + case HloOpcode::kCall: + case HloOpcode::kWhile: + // Call and while must be called from a computation with global + // allocations as they may return references to buffers inside the + // called computation which cannot be thread-local. 
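+            // For example, a kWhile inside a computation mapped by kMap would
+            // need global allocations inside a thread-local context, which is
+            // the case rejected below.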
+ if (is_thread_local) { + return InvalidArgument( + "computation %s cannot contain call/while op because it " + "requires thread-local buffer allocations", + computation->name().c_str()); + } + worklist.push_back(std::make_pair(subcomputation, + false)); // Not thread local. + break; + case HloOpcode::kMap: + case HloOpcode::kReduce: + case HloOpcode::kReduceWindow: + case HloOpcode::kSelectAndScatter: + case HloOpcode::kFusion: + // Map/reduce etc computations are always thread-local. + worklist.push_back(std::make_pair(subcomputation, + true)); // Thread local. + break; + default: + return InternalError( + "Unexpected calling opcode: %s", + HloOpcodeString(instruction->opcode()).c_str()); + } + } + } + } + + // Add the computations to the vectors in post order. + for (auto* computation : module->MakeComputationPostOrder()) { + if (thread_local_set.count(computation) > 0) { + thread_local_computations->push_back(computation); + } else if (global_set.count(computation) > 0) { + global_computations->push_back(computation); + } + // If the computation is not reachable from the entry computation, then it + // will not appear in either thread_local_set or global_set. We don't bother + // assigning buffers for these. + } + return tensorflow::Status::OK(); +} + +} // namespace + +/* static */ +StatusOr> BufferAssigner::Run( + const HloModule* module, std::unique_ptr hlo_ordering, + BufferSizeFunction buffer_size, bool colocate_related_buffers, + const std::vector* hlos_to_allocate) { + BufferAssigner assigner(std::move(buffer_size), colocate_related_buffers); + return assigner.CreateAssignment(module, std::move(hlo_ordering), + hlos_to_allocate); +} + +/* static */ +StatusOr> BufferAssigner::Run( + const HloModule* module, std::unique_ptr hlo_ordering, + int64 pointer_size) { + return BufferAssigner::Run(module, std::move(hlo_ordering), + [pointer_size](const LogicalBuffer& buffer) { + return ShapeUtil::IsOpaque(buffer.shape()) + ? 0 + : ShapeUtil::ByteSizeOf( + buffer.shape(), pointer_size); + }, + /*colocate_related_buffers=*/true); +} + +bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, + const LogicalBuffer& buffer, + BufferAssignment* assignment) { + CHECK(!assignment->HasAllocation(buffer)) + << "buffer " << buffer << " already has an allocation assigned."; + + VLOG(4) << "Trying to assign " << buffer.ToString() + << " to allocation: " << allocation->ToString(); + + if (buffer_size_(buffer) > allocation->size()) { + VLOG(4) << "Can't assign: buffer is larger than allocation (" + << buffer_size_(buffer) << " > " << allocation->size() << ")"; + return false; + } + + if (allocation->is_entry_computation_parameter()) { + VLOG(4) << "Can't assign: allocation holds parameter"; + return false; + } + + for (const LogicalBuffer* assigned_buffer : allocation->assigned_buffers()) { + if (assignment->liveness().MayInterfere(*assigned_buffer, buffer)) { + VLOG(4) << "Can't assign: assignee " << assigned_buffer->ToString() + << " may interfere with " << buffer.ToString(); + return false; + } + } + + // If the buffer is live out of the computation then it should only be + // assigned a buffer which exactly fits the result to avoid wasting memory + // (result buffers can have arbitrary lifetimes). 
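+  // For example, reusing a 1024-byte temporary allocation for a 16-byte
+  // live-out result would keep all 1024 bytes alive for as long as the caller
+  // holds the result, so only an exact size match is accepted below.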
+ if (assignment->liveness().MaybeLiveOut(buffer) && + allocation->size() != buffer_size_(buffer)) { + VLOG(4) << "Can't assign: buffer " << buffer.ToString() + << "is live out and size not the same as allocation"; + return false; + } + + assignment->AddAssignment(buffer, allocation); + return true; +} + +tensorflow::Status BufferAssigner::AssignBuffersForComputation( + const HloComputation* computation, bool is_thread_local, + const std::unordered_set* hlos_to_allocate, + BufferAssignment* assignment) { + // Buffers are sorted and assigned to BufferAllocations in decreasing order of + // size. + std::vector sorted_buffers; + for (auto& instruction : computation->instructions()) { + if (hlos_to_allocate == nullptr || + hlos_to_allocate->count(instruction.get()) > 0) { + // Add all buffers which this instruction defines. Instruction which don't + // define buffers (eg, bitcast which just forwards a pointer) don't need + // any allocations. + for (const LogicalBuffer* buffer : + assignment->points_to_analysis().GetBuffersDefinedByInstruction( + instruction.get())) { + sorted_buffers.push_back(buffer); + } + } + } + + // Generate a post order sort of instructions for sorting of the + // LogicalBuffers. + std::unordered_map post_order_position; + int position = 0; + for (auto* instruction : computation->MakeInstructionPostOrder()) { + post_order_position.emplace(instruction, position); + position++; + } + + // Sort the LogicalBuffers first by size. We assign the larger LogicalBuffers + // first for simplicity. This means any previously created BufferAllocation is + // necessarily large enough to hold the output of the current Buffer in + // consideration. + // + // As a secondary sorting criteria, use post order position of the HLO + // instruction which defines the buffer. This means an instruction will appear + // after its operands (assuming operands are the same/larger size) enabling + // the important reuse case where an elementwise instruction reuses one of its + // operand's buffer. This improves locality. + std::sort(sorted_buffers.begin(), sorted_buffers.end(), + [this, &post_order_position](const LogicalBuffer* a, + const LogicalBuffer* b) { + int64 a_size = buffer_size_(*a); + int64 b_size = buffer_size_(*b); + if (a_size == b_size) { + // For instructions with the same size buffers, sort them in + // post order. + return post_order_position.at(a->instruction()) < + post_order_position.at(b->instruction()); + } else { + // We want the HLOs sorted in reverse order by size so use ">". + return a_size > b_size; + } + }); + + // BufferAllocations are necessarily created in decreasing size order. Keep + // indices of previously created BufferAllocations in allocation_indices. + std::vector allocation_indices; + for (const auto* buffer : sorted_buffers) { + VLOG(3) << "Assigning allocation to: " << buffer->ToString(); + if (colocated_buffers_.find(buffer) != colocated_buffers_.end()) { + // Colocated buffers are currently assigned in an earlier pass. + continue; + } + + TF_RET_CHECK(!assignment->HasAllocation(*buffer)); + + if (buffer->instruction()->opcode() == HloOpcode::kConstant) { + // No BufferAllocations for constants. + // TODO(b/32248867): For consistency, constants should get allocations. + continue; + } + + if (buffer->instruction()->opcode() == HloOpcode::kParameter && + computation == computation->parent()->entry_computation()) { + // If the LogicalBuffer is part of an external parameter, creates a new + // allocation and sets its parameter number. 
Parameters of non-entry + // computations do not need special allocations because they live inside + // callers. + BufferAllocation* allocation = + assignment->NewAllocation(*buffer, buffer_size_(*buffer), + /*is_thread_local=*/false); + allocation->set_entry_computation_parameter( + buffer->instruction()->parameter_number()); + VLOG(3) << "New allocation for entry computation parameter: " + << buffer->ToString(); + continue; + } + + legacy_flags::BufferAssignmentFlags* flags = + legacy_flags::GetBufferAssignmentFlags(); + if (!flags->xla_enable_buffer_reuse || is_thread_local || + buffer->instruction()->opcode() == HloOpcode::kCustomCall) { + // Custom call operations never have reusable buffers. Also we do not + // reuse thread-local buffers for now, because they are dynamically + // allocated and their lifetimes are hard to compute. + assignment->NewAllocation(*buffer, buffer_size_(*buffer), + is_thread_local); + continue; + } + + if (buffer->instruction()->opcode() == HloOpcode::kCall && + buffer->IsTopLevel()) { + // Assign the kCall instruction the same allocation as the root of the + // called computation. The points-to set of the root of the called + // computation must be unambigous so we know statically the allocation for + // the root. + // + // TODO(b/32491382): This is a hack. To properly handle this case + // points-to analysis, liveness analysis, and buffer assignment need to + // module-scope rather than computation-scope. + HloInstruction* call = buffer->instruction(); + HloInstruction* computation_root = call->to_apply()->root_instruction(); + + // The buffer of the root of the called computation must be unambiguous. + const auto& root_points_to = assignment->GetPointsToSet(computation_root); + if (root_points_to.IsAmbiguous()) { + return Unimplemented( + "kCall of a computation with an ambiguous root points-to set"); + } + CHECK_EQ(1, root_points_to.element(/*index=*/{}).size()); + const LogicalBuffer* root_buffer = + root_points_to.element(/*index=*/{})[0]; + BufferAllocation* root_allocation = + assignment->GetMutableAssignedAllocation(*root_buffer); + + // Can't use MaybeAssignBuffer here because buffer liveness conservatively + // assumes buffers in different computations always interfere. + CHECK_GE(root_allocation->size(), buffer_size_(*buffer)); + assignment->AddAssignment(*buffer, root_allocation); + continue; + } + + // First try to assign a LogicalBuffer to one of its operand allocations to + // improve locality. This is only possible with elementwise operations + // (checked in liveness analysis) which are necessarily top-level + // array-shaped buffers. + if (buffer->IsTopLevel() && !buffer->IsTuple()) { + for (auto* operand : buffer->instruction()->operands()) { + bool assigned_operand = false; + for (const auto& operand_allocation : + assignment->GetAllocations(operand, /*index=*/{})) { + BufferAllocation* allocation = + assignment->GetMutableAllocation(operand_allocation.index()); + if (colocated_buffer_allocations_.find(allocation->index()) == + colocated_buffer_allocations_.end()) { + // TODO(b/32491382) Colocated buffers are currently assigned in an + // earlier pass, and so can break the "increasing allocation size" + // invariant in this function (causing this CHECK to fail). However, + // the call to MaybeAssignBuffer is safe as it returns false if + // allocation.size < buffer.size. 
+ CHECK_GE(allocation->size(), buffer_size_(*buffer)); + } + if (MaybeAssignBuffer(allocation, *buffer, assignment)) { + VLOG(3) << "Reusing (operand) allocation for: " + << buffer->ToString(); + assigned_operand = true; + break; + } + } + if (assigned_operand) { + break; + } + } + } + + if (!assignment->HasAllocation(*buffer)) { + // Find the smallest buffer which can be reused iterating from end of + // allocation_indices (smallest) to beginning (largest). + for (int allocation_index = allocation_indices.size() - 1; + allocation_index >= 0; allocation_index--) { + BufferAllocation* allocation = assignment->GetMutableAllocation( + allocation_indices[allocation_index]); + // Instructions are iterated in increasing buffer size, so any + // previously create allocation must be large enough to hold this + // instruction's output (with the exception of colocated buffers). + if (colocated_buffer_allocations_.find(allocation->index()) == + colocated_buffer_allocations_.end()) { + // TODO(b/32491382) Colocated buffers are currently assigned in an + // earlier pass, and so can break the "increasing allocation size" + // invariant in this function (causing this CHECK to fail). However, + // the call to MaybeAssignBuffer is safe as it returns false if + // allocation.size < buffer.size. + CHECK_GE(allocation->size(), buffer_size_(*buffer)); + } + + if (MaybeAssignBuffer(allocation, *buffer, assignment)) { + VLOG(3) << "Reusing buffer for: " << buffer->ToString(); + break; + } + } + } + if (!assignment->HasAllocation(*buffer)) { + auto* allocation = assignment->NewAllocation( + *buffer, buffer_size_(*buffer), is_thread_local); + VLOG(3) << "New allocation for: " << buffer->ToString(); + allocation_indices.push_back(allocation->index()); + } + } + return tensorflow::Status::OK(); +} + +void BufferAssigner::AddBufferToColocatedBufferSet( + const HloInstruction* instruction, const ShapeIndex& index, + const TuplePointsToAnalysis& points_to_analysis, + BufferAssigner::ColocatedBufferSet* colocated_buffer_set) { + const auto& points_to = points_to_analysis.GetPointsToSet(instruction); + // CopyInsertion ensures root points-to set is unambiguous and distinct. + CHECK(!points_to.IsAmbiguous()); + CHECK(points_to.IsDistinct()); + colocated_buffer_set->push_back(points_to.element(index)[0]); +} + +// Builds sets of buffers in 'colocated_buffer_sets' which should be colocated +// in the same allocation (currently just supports kWhile). +std::vector +BufferAssigner::BuildColocatedBufferSets( + const HloModule* module, const TuplePointsToAnalysis& points_to_analysis) { + std::vector colocated_buffer_sets; + for (auto& computation : module->computations()) { + for (auto& instruction : computation->instructions()) { + if (instruction->opcode() != HloOpcode::kWhile) { + continue; + } + HloInstruction* while_hlo = instruction.get(); + TF_CHECK_OK(ShapeUtil::ForEachSubshape( + while_hlo->shape(), + [this, &points_to_analysis, &while_hlo, &colocated_buffer_sets]( + const Shape& /*subshape*/, const ShapeIndex& index) { + ColocatedBufferSet colocated_buffer_set; + // Add while.init. + AddBufferToColocatedBufferSet(while_hlo->operand(0), index, + points_to_analysis, + &colocated_buffer_set); + // Add while.result. + AddBufferToColocatedBufferSet(while_hlo, index, points_to_analysis, + &colocated_buffer_set); + // Add while.cond.parameter. + AddBufferToColocatedBufferSet( + while_hlo->while_condition()->parameter_instruction(0), index, + points_to_analysis, &colocated_buffer_set); + // Add while.body.parameter. 
+ AddBufferToColocatedBufferSet( + while_hlo->while_body()->parameter_instruction(0), index, + points_to_analysis, &colocated_buffer_set); + // Add while.body.root. + AddBufferToColocatedBufferSet( + while_hlo->while_body()->root_instruction(), index, + points_to_analysis, &colocated_buffer_set); + + colocated_buffer_sets.push_back(std::move(colocated_buffer_set)); + return tensorflow::Status::OK(); + })); + } + } + return colocated_buffer_sets; +} + +// Assigns all colocated buffer sets in 'colocated_buffer_sets' to the same +// allocation in 'assignment'. +void BufferAssigner::AssignColocatedBufferSets( + const std::vector& colocated_buffer_sets, + BufferAssignment* assignment) { + for (const auto& colocated_buffer_set : colocated_buffer_sets) { + BufferAllocation* allocation = nullptr; + for (const auto& buffer : colocated_buffer_set) { + if (colocated_buffers_.find(buffer) != colocated_buffers_.end()) { + // ColocatedBufferSet duplicates can occur if a buffer is forwarded + // from one instruction to another (i.e. while.body param to root). + continue; + } + if (allocation == nullptr) { + // TODO(b/32491382) Avoid current trivial solution of using new + // allocations for each colocated buffer set. When liveness has + // module-level scope, we can allow buffers to be shared across + // computations (in some cases). + allocation = assignment->NewAllocation(*buffer, buffer_size_(*buffer), + /*is_thread_local=*/false); + colocated_buffer_allocations_.insert(allocation->index()); + } else { + assignment->AddAssignment(*buffer, allocation); + } + colocated_buffers_.insert(buffer); + } + } +} + +StatusOr> BufferAssigner::CreateAssignment( + const HloModule* module, std::unique_ptr hlo_ordering, + const std::vector* hlos_to_allocate) { + TF_ASSIGN_OR_RETURN(std::unique_ptr liveness, + BufferLiveness::Run(module, std::move(hlo_ordering))); + + std::vector thread_local_computations; + std::vector global_computations; + VLOG(1) << "Assigning buffers to module " << module->name(); + if (hlos_to_allocate != nullptr) { + VLOG(3) << "LogicalBuffer assignment restricted to hlos: "; + for (auto hlo : *hlos_to_allocate) { + VLOG(3) << " " << hlo->parent()->name() << "::" << hlo->name(); + } + } + XLA_VLOG_LINES(3, module->ToString()); + XLA_VLOG_LINES(3, liveness->ToString()); + XLA_VLOG_LINES(3, liveness->points_to_analysis().ToString()); + + TF_RETURN_IF_ERROR(GatherComputationsByAllocationType( + module, &thread_local_computations, &global_computations)); + + // Set of HLO's to allocate if hlos_to_allocate is given. Passed as a set to + // AssignBuffersForComputation for fast membership testing. + std::unique_ptr> hlo_set; + if (hlos_to_allocate != nullptr) { + hlo_set = MakeUnique>( + hlos_to_allocate->begin(), hlos_to_allocate->end()); + } + + // Can't use MakeUnique because BufferAssignment constructor is private. + std::unique_ptr assignment( + new BufferAssignment(module, std::move(liveness))); + + // Assign buffers with the tightest constraints first (colocated buffer sets). + // Once b/32491382 enables module-level liveness analysis, we may be able + // to assign colocated buffers (or at least reuse their allocation for + // buffers outside of the set) in AssignBuffersForComputation. 
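+  // For a single kWhile instruction, BuildColocatedBufferSets (above) groups,
+  // per shape index, the buffers of the init operand, the condition parameter,
+  // the body parameter, the body root and the while result into one set.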
+ if (colocate_related_buffers_) { + std::vector colocated_buffer_sets = + BuildColocatedBufferSets(module, assignment->points_to_analysis()); + AssignColocatedBufferSets(colocated_buffer_sets, assignment.get()); + } + + for (auto* computation : global_computations) { + TF_RETURN_IF_ERROR(AssignBuffersForComputation( + computation, + /*is_thread_local=*/false, hlo_set.get(), assignment.get())); + } + for (auto* computation : thread_local_computations) { + TF_RET_CHECK(computation != module->entry_computation()); + TF_RETURN_IF_ERROR(AssignBuffersForComputation( + computation, + /*is_thread_local=*/true, hlo_set.get(), assignment.get())); + } + + // Mark all buffers which may be live out of the entry computation as + // "liveout". + auto entry = module->entry_computation(); + auto root_instruction = entry->root_instruction(); + const PointsToSet& root_points_to = + assignment->GetPointsToSet(root_instruction); + TF_RETURN_IF_ERROR(root_points_to.ForEachElement([&assignment]( + const ShapeIndex& /*index*/, bool /*is_leaf*/, + const std::vector& buffers) { + for (auto buffer : buffers) { + if (assignment->HasAllocation(*buffer)) { + assignment->GetMutableAssignedAllocation(*buffer)->set_maybe_live_out( + true); + } + } + return tensorflow::Status::OK(); + })); + + XLA_VLOG_LINES(2, assignment->ToString()); + + // Compute sizes of various kinds of buffers for logging. + int64 total_size = 0; + int64 parameter_size = 0; + for (auto& allocation : assignment->Allocations()) { + if (allocation.is_entry_computation_parameter()) { + parameter_size += allocation.size(); + } + total_size += allocation.size(); + } + + // Compute the total size of the output. Iterate over the subshapes and sum up + // the sizes of the buffers for each subshape. + int64 output_size = 0; + HloInstruction* root = module->entry_computation()->root_instruction(); + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshape( + root->shape(), [this, &output_size, root, &assignment]( + const Shape& /*subshape*/, const ShapeIndex& index) { + const auto& allocations = assignment->GetAllocations(root, index); + if (allocations.size() > 0) { + output_size += allocations.begin()->size(); + } + return tensorflow::Status::OK(); + })); + + VLOG(1) << "Allocation sizes for module " << module->name() << ":"; + VLOG(1) << " parameter allocation total size: " + << tensorflow::strings::HumanReadableNumBytes(parameter_size); + VLOG(1) << " output allocation total size: " + << tensorflow::strings::HumanReadableNumBytes(output_size); + VLOG(1) << " temp allocation total size: " + << tensorflow::strings::HumanReadableNumBytes( + total_size - parameter_size - output_size); + VLOG(1) << " total allocation size: " + << tensorflow::strings::HumanReadableNumBytes(total_size); + return std::move(assignment); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h new file mode 100644 index 0000000000..af455de298 --- /dev/null +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -0,0 +1,358 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_ASSIGNMENT_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_ASSIGNMENT_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/service/buffer_liveness.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/logical_buffer.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// This class abstracts an allocation of contiguous memory which can hold the +// values described by LogicalBuffers. A BufferAllocation may hold different +// LogicalBuffers at different times, but currently never more than one +// LogicalBuffer simultaneously. The abstraction includes information required +// by the backends for allocation, use, and deallocation of the buffer. This +// includes the LogicalBuffers which are held in this allocation through the +// execution of the computation. +class BufferAllocation { + public: + // Holds a unique identifier for each allocation. Values are assigned + // contiguously and can be used as array indexes. + using Index = int64; + + BufferAllocation(Index index, int64 size, bool is_thread_local) + : index_(index), size_(size), is_thread_local_(is_thread_local) {} + ~BufferAllocation() {} + + // Adds a LogicalBuffer to the set assigned to this buffer. + void AddAssignment(const LogicalBuffer& buffer); + + // Whether this allocation is used in a parallel calling context such as + // inside of a map or reduce computation. Such allocations need to be thread + // local. + bool is_thread_local() const { return is_thread_local_; } + + // Whether this allocation holds a LogicalBuffer from a parameter of the entry + // computation. These buffers have lifetimes which may be longer than the + // XLA computation. + bool is_entry_computation_parameter() const { + return is_entry_computation_parameter_; + } + // If this allocation holds a Buffer from a parameter of the entry + // computation, this methods returns the parameter number. CHECKs otherwise. + int64 parameter_number() const { + CHECK(is_entry_computation_parameter_); + return parameter_number_; + } + // Sets that this allocation holds a LogicalBuffer from a parameter of the + // entry computation. + void set_entry_computation_parameter(int64 parameter_number) { + is_entry_computation_parameter_ = true; + parameter_number_ = parameter_number; + } + + // Returns/sets whether this allocation is assigned a LogicalBuffer which may + // be live out of the entry computation. 
+ bool maybe_live_out() const { return maybe_live_out_; } + void set_maybe_live_out(bool value) { maybe_live_out_ = value; } + + // Returns the size of the allocation. Necessarily this must be at least as + // large as any LogicalBuffer assigned to this allocation. + int64 size() const { return size_; } + + // Access to the logical buffers assigned to this allocation. + const std::vector& assigned_buffers() const { + return assigned_buffers_; + } + + Index index() const { return index_; } + + string ToString() const; + + // Whether the buffer is a parameter to or live out of the entry computation. + bool IsInputOrOutput() const { + return is_entry_computation_parameter() || maybe_live_out(); + } + + // Whether the buffer is a temporary buffer allocated before + // Executable::ExecuteOnStream. + bool IsPreallocatedTempBuffer() const { + // Parameters do not need temporary buffers. + return !is_entry_computation_parameter() && + // LogicalBuffers that maybe pointed to by the output should live out + // of the computation. + !maybe_live_out() && + // Thread-local buffers are allocated using `alloca`s. + !is_thread_local(); + } + + bool operator==(const BufferAllocation& other) const { + return index_ == other.index_; + } + bool operator!=(const BufferAllocation& other) const { + return !(*this == other); + } + bool operator<(const BufferAllocation& other) const { + return index() < other.index(); + } + + private: + // The index of the allocation in the BufferAssignment. + Index index_; + + // Size of the allocation in bytes. + int64 size_; + + // Whether this buffer needs to be thread-local. + bool is_thread_local_; + + // Whether this allocation holds an entry computation parameter. Entry + // computation parameters are special be cause they have lifetimes which may + // outlast the computation. + bool is_entry_computation_parameter_ = false; + + // If this allocation holds an entry computation parameter, this field + // indicates the index (starting from 0) of the parameter. + int64 parameter_number_ = 0; + + // Whether the allocation contains a LogicalBuffer which may be live-out of + // the entry computation. Note that this flag is conservatively computed by + // TuplePointsToAnalysis. That is, an allocation marked `maybe_live_out_` + // might not actually escape. + bool maybe_live_out_ = false; + + // The set of buffers assigned to this allocation. + std::vector assigned_buffers_; +}; + +// Add stream operator for nicer output of CHECK/RET_CHECK failures. +std::ostream& operator<<(std::ostream& out, const BufferAllocation& s); + +// This class encapsulates an assignment of the LogicalBuffers in an XLA +// module to a set of BufferAllocations. +class BufferAssignment { + public: + // Returns the vector containing all buffer allocations in this assignment. + const std::vector& Allocations() const { + return allocations_; + } + + // Returns whether the given buffer has been assigned an allocation. + bool HasAllocation(const LogicalBuffer& buffer) const; + + // Returns the allocation that a particular LogicalBuffer has been assigned + // to. CHECKs if buffer has not been assigned an allocation. + const BufferAllocation& GetAssignedAllocation( + const LogicalBuffer& buffer) const; + + // Returns the allocation with the given index. CHECKs if no allocation exists + // with the given index. 
+ const BufferAllocation& GetAllocation(BufferAllocation::Index index) const; + + // Builds and returns a vector containing the allocations which might contain + // the subvalue at the given index of given instruction. + std::set GetAllocations(const HloInstruction* instruction, + const ShapeIndex& index) const; + + // Convenience function which returns whether the top-level buffer of the + // instruction (index == {}) is assigned an allocation. + bool HasTopLevelAllocation(const HloInstruction* instruction) const; + + // Convenience function which returns the unique buffer allocation containing + // the buffer at the given index of the given instruction. If an allocation is + // not assigned or the allocation cannot be determined at compile time then an + // error is returned. + StatusOr GetUniqueAllocation( + const HloInstruction* instruction, const ShapeIndex& index) const; + // Like GetUniqueAllocation but fixes the index to the top-level of the shape + // (index = {}). + StatusOr GetUniqueTopLevelAllocation( + const HloInstruction* instruction) const; + // Like GetUniqueTopLevelAllocation but returns the allocation for the output + // of the entry computation of the HLO module (ie, the result of the XLA + // computation). + StatusOr GetUniqueTopLevelOutputAllocation() const; + + // Returns the set LogicalBuffers which may be the source of the value at the + // given index and instruction. + const std::vector& GetSourceBuffers( + const HloInstruction* instruction, const ShapeIndex& index) const { + return GetPointsToSet(instruction).element(index); + } + + // Returns the underlying points-to analysis used for this assignment. + const TuplePointsToAnalysis& points_to_analysis() const { + return liveness_->points_to_analysis(); + } + + string ToString() const; + + private: + // Only BufferAssigner can build or modify BufferAssignments. + friend class BufferAssigner; + + explicit BufferAssignment(const HloModule* module, + std::unique_ptr liveness) + : module_(module), liveness_(std::move(liveness)) {} + + // Creates and returns a new BufferAllocation. Ownership is maintained + // internally. The allocation initially has only the given LogicalBuffer + // assigned to it. `is_thread_local` indicates whether this buffer needs to be + // thread-local. + BufferAllocation* NewAllocation(const LogicalBuffer& buffer, int64 size, + bool is_thread_local); + + // Adds a LogicalBuffer to the set assigned to the given allocation. + void AddAssignment(const LogicalBuffer& buffer, BufferAllocation* allocation); + + // Returns the BufferLiveness object used to construct this assignment. + const BufferLiveness& liveness() { return *liveness_; } + + // Convenience function which returns the PointsToSet for the given + // instruction. Extracted from the liveness object. + const PointsToSet& GetPointsToSet(const HloInstruction* instruction) const; + + // Mutable accessors for allocations. + BufferAllocation* GetMutableAssignedAllocation(const LogicalBuffer& buffer); + BufferAllocation* GetMutableAllocation(BufferAllocation::Index index); + + // The vector of buffer allocations. Indexed by BufferAllocation::Index. + std::vector allocations_; + + // Maps Buffers to the index of the BufferAllocation which holds the buffer. + std::map + allocation_index_for_buffer_; + + const HloModule* module_; + std::unique_ptr liveness_; + + TF_DISALLOW_COPY_AND_ASSIGN(BufferAssignment); +}; + +// A class which constructs a buffer assignment. 
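+//
+// Rough usage sketch (mirrors the pointer_size overload of Run() declared
+// below; DependencyHloOrdering is an assumed choice of ordering here, any
+// HloOrdering implementation works):
+//
+//   TF_ASSIGN_OR_RETURN(
+//       std::unique_ptr<BufferAssignment> assignment,
+//       BufferAssigner::Run(module, MakeUnique<DependencyHloOrdering>(module),
+//                           /*pointer_size=*/sizeof(void*)));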
+class BufferAssigner { + public: + // Build and return a BufferAssignment for the given module. The given + // HloOrdering is used to determine buffer liveness. buffer_size is a function + // which returns the size of a LogicalBuffer. If hlos_to_allocate is not null + // then only instructions in this vector are considered for buffer + // assignment. If hlos_to_allocate is null then all instructions are + // considered. If 'colocate_related_buffers' is true, related LogicalBuffers + // will be colocated in the same allocation (i.e buffers for while result + // will share an allocation with buffers related to that same while + // instruction: init operand, condition/body parameter and body result). + using BufferSizeFunction = std::function; + static StatusOr> Run( + const HloModule* module, std::unique_ptr hlo_ordering, + BufferSizeFunction buffer_size, bool colocate_related_buffers, + const std::vector* hlos_to_allocate = nullptr); + + // Overload of Run which uses ShapeUtil::ByteSizeOf to determine buffer size + // and assigns buffers to all HLO instructions in the module. + static StatusOr> Run( + const HloModule* module, std::unique_ptr hlo_ordering, + int64 pointer_size); + + private: + explicit BufferAssigner(BufferSizeFunction buffer_size, + bool colocate_related_buffers) + : buffer_size_(std::move(buffer_size)), + colocate_related_buffers_(colocate_related_buffers) {} + virtual ~BufferAssigner() = default; + + // Create a buffer assignment. + StatusOr> CreateAssignment( + const HloModule* module, std::unique_ptr hlo_ordering, + const std::vector* hlos_to_allocate = nullptr); + + // Assigns buffers to the instructions in the given computation. "assignment" + // is modified to reflect the new buffer assignments. If is_thread_local is + // true, then all assigned buffers have the is_thread_local flag set to + // true. If hlos_to_allocate is not null it indicates which HLOs to include in + // buffer assignment. If null, all instructions in the computation are + // included. + tensorflow::Status AssignBuffersForComputation( + const HloComputation* computation, bool is_thread_local, + const std::unordered_set* hlos_to_allocate, + BufferAssignment* assignment); + + // Tries to assign the given instruction to the given buffer. Returns if the + // assignment was successful. + bool MaybeAssignBuffer(BufferAllocation* allocation, + const LogicalBuffer& buffer, + BufferAssignment* assignment); + + using ColocatedBufferSet = std::vector; + + // Returns a vector of ColocatedBufferSet objects, where each + // ColocatedBufferSet aggregates a set of related LogicalBuffers from 'module' + // which should be colocated in the same buffer allocation. + std::vector BuildColocatedBufferSets( + const HloModule* module, const TuplePointsToAnalysis& points_to_analysis); + + // For each buffer set in 'colocated_buffer_sets', assigns all buffers in the + // same set to the same buffer allocation in 'assignment'. + void AssignColocatedBufferSets( + const std::vector& colocated_buffer_sets, + BufferAssignment* assignment); + + // Checks that points-to set of 'instruction' is unambiguous and distinct + // (ensured by CopyInsertion), then adds buffer from point-to set at 'index' + // to 'colocated_buffer_set'. + void AddBufferToColocatedBufferSet( + const HloInstruction* instruction, const ShapeIndex& index, + const TuplePointsToAnalysis& points_to_analysis, + BufferAssigner::ColocatedBufferSet* colocated_buffer_set); + + const HloModule* module_; + + // Function which returns the buffer size for a given shape. 
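+  // For example, the pointer_size overload of Run() supplies:
+  //
+  //   [pointer_size](const LogicalBuffer& buffer) {
+  //     return ShapeUtil::IsOpaque(buffer.shape())
+  //                ? 0
+  //                : ShapeUtil::ByteSizeOf(buffer.shape(), pointer_size);
+  //   }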
+ BufferSizeFunction buffer_size_; + + // Indicates whether related buffers should share the same buffer allocation. + const bool colocate_related_buffers_; + + // Set of colocated buffers populated in AssignColocatedBufferSets. + std::unordered_set colocated_buffers_; + + // Set of allocations containing colocated buffers. + std::unordered_set colocated_buffer_allocations_; + + TF_DISALLOW_COPY_AND_ASSIGN(BufferAssigner); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_ASSIGNMENT_H_ diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc new file mode 100644 index 0000000000..56138a7ee6 --- /dev/null +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -0,0 +1,1051 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/buffer_assignment.h" + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/computation_tracker.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/macros.h" + +namespace xla { + +namespace { + +// DFS visitor that collects the instructions referenced by a computation +// without descending into nested computations, i.e., only from the operands. +class InstructionListVisitor : public DfsHloVisitorWithDefault { + public: + explicit InstructionListVisitor(const HloInstruction* root) : root_(root) {} + + Status DefaultAction(HloInstruction* hlo) override { + // For each instruction, just push it on the list after walking the + // operands. + instructions_.push_back(hlo); + VLOG(0) << "List instruction " << hlo->ToString(); + return Status::OK(); + } + + std::vector GetInstructions() { return instructions_; } + + private: + // The instruction root of the computation. + const HloInstruction* root_; + + // The full set of instructions found (may be duplicates, e.g., kParameter). 
+ std::vector instructions_; + + TF_DISALLOW_COPY_AND_ASSIGN(InstructionListVisitor); +}; + +const std::vector GetInstructions(HloInstruction* root) { + InstructionListVisitor main_list(root); + TF_CHECK_OK(root->Accept(&main_list)); + return main_list.GetInstructions(); +} + +class BufferAssignmentTest : public HloTestBase { + protected: + BufferAssignmentTest() : computation_tracker_() {} + ~BufferAssignmentTest() override {} + + // Builds an x+1.0 computation to use in a Map. + std::unique_ptr BuildMapComputationPlus1(const string& name) { + auto builder = HloComputation::Builder(name); + auto param = + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x")); + auto value = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + builder.AddInstruction( + HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param, value)); + return builder.Build(); + } + + // Builds a simple compare-to-limit (x < 4) computation for a While. + // + // condition: + // const4[s32] -----------------------------------\ + // \ + // param[(s32,f32[4])] --- get-tuple-element[0] --- less-than + // + std::unique_ptr BuildWhileConditionComputation( + const string& name) { + auto builder = HloComputation::Builder(name); + auto const4 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(4))); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, t_s32_f32v4_, "x")); + auto index = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(const4->shape(), param, 0)); + builder.AddInstruction( + HloInstruction::CreateBinary(r0f32_, HloOpcode::kLt, index, const4)); + return builder.Build(); + } + + // Builds a simple body computation for a While. + // + // body: + // constv[f32[4]] --------------------------------------\ + // \ + // /--- get-tuple-elementv[1] --- addv ---\ + // param[(s32,f32[4])] ---| tuple + // \--- get-tuple-elementc[0] --- addc ---/ + // / + // const1[s32] -----------------------------------------/ + // + std::unique_ptr BuildWhileBodyComputation( + const string& name) { + auto builder = HloComputation::Builder(name); + auto const1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); + auto constv = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, t_s32_f32v4_, "x")); + auto indexc = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(const1->shape(), param, 0)); + auto addc = builder.AddInstruction(HloInstruction::CreateBinary( + indexc->shape(), HloOpcode::kAdd, indexc, const1)); + auto indexv = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(constv->shape(), param, 1)); + auto addv = builder.AddInstruction(HloInstruction::CreateBinary( + constv->shape(), HloOpcode::kAdd, indexv, constv)); + builder.AddInstruction(HloInstruction::CreateTuple({addc, addv})); + return builder.Build(); + } + + // Verifies that the given instruction hlo has a valid input buffer assigned, + // i.e., the parameter number matches the op's. 
+ const BufferAllocation& GetAssignedInputAllocation( + const BufferAssignment& buffers, HloInstruction* hlo) { + LOG(INFO) << "Checking input: " << hlo->ToString(); + const BufferAllocation& buffer = + *buffers.GetUniqueTopLevelAllocation(hlo).ConsumeValueOrDie(); + EXPECT_EQ(hlo->parameter_number(), buffer.parameter_number()); + return buffer; + } + + // Verifies that the given instruction hlo has a valid output buffer + // assigned, and returns it. + const BufferAllocation& GetAssignedOutputAllocation( + const BufferAssignment& buffers, HloInstruction* hlo) { + LOG(INFO) << "Checking output: " << hlo->ToString(); + const BufferAllocation& buffer = GetTopLevelAllocation(buffers, hlo); + return buffer; + } + + // Returns the allocation for the given instruction. + const BufferAllocation& GetAllocation(const BufferAssignment& buffers, + const HloInstruction* hlo, + const ShapeIndex& index) { + return *buffers.GetUniqueAllocation(hlo, index).ConsumeValueOrDie(); + } + const BufferAllocation& GetTopLevelAllocation(const BufferAssignment& buffers, + const HloInstruction* hlo) { + return *buffers.GetUniqueTopLevelAllocation(hlo).ConsumeValueOrDie(); + } + + // Verifies that all instructions in the given instruction list except + // kConstant have assigned buffers, and returns their total size. If min_index + // and max_index are not nullptr, the minimum and maximum buffer indices in + // the assignment are written into them. + int64 ValidateBuffers(const std::vector& instructions, + const BufferAssignment& buffers) { + // Verifies all instructions have buffers, and gets the index ranges. + for (const HloInstruction* hlo : instructions) { + if (!buffers.HasTopLevelAllocation(hlo)) { + // If `hlo` has no assigned buffer, it is either a constant or a nested + // parameter. + EXPECT_TRUE(HloOpcode::kConstant == hlo->opcode() || + HloOpcode::kParameter == hlo->opcode()); + continue; + } + } + + // Gets the total size of all buffers assigned. + int64 total_size = 0; + for (auto& allocation : buffers.Allocations()) { + total_size += allocation.size(); + } + return total_size; + } + + // Returns true if the buffers assigned to instructions in "a" are distinct + // from the buffers assigned to those in "b" (ie, intersection is empty). + bool BuffersDistinct(const std::vector& a, + const std::vector& b, + const BufferAssignment& assignment) { + std::set a_buffers; + for (const HloInstruction* instruction : a) { + if (assignment.HasTopLevelAllocation(instruction)) { + a_buffers.insert(assignment.GetUniqueTopLevelAllocation(instruction) + .ConsumeValueOrDie() + ->index()); + } + } + + for (const HloInstruction* instruction : b) { + if (assignment.HasTopLevelAllocation(instruction)) { + if (a_buffers.count(assignment.GetUniqueTopLevelAllocation(instruction) + .ConsumeValueOrDie() + ->index())) { + return false; + } + } + } + return true; + } + + // Computation tracker for nested computations. + ComputationTracker computation_tracker_; + + // Shapes for use in the examples. 
+  Shape s32_ = ShapeUtil::MakeShape(xla::S32, {});
+  Shape r0f32_ = ShapeUtil::MakeShape(xla::F32, {});
+  Shape f32vec4_ = ShapeUtil::MakeShape(F32, {4});
+  Shape f32vec10_ = ShapeUtil::MakeShape(F32, {10});
+  Shape f32vec100_ = ShapeUtil::MakeShape(F32, {100});
+  Shape f32a100x10_ = ShapeUtil::MakeShape(F32, {100, 10});
+  Shape t_s32_f32v4_ = ShapeUtil::MakeTupleShape({s32_, f32vec4_});
+  Shape t_s32_f32v10_ = ShapeUtil::MakeTupleShape({s32_, f32vec10_});
+};
+
+namespace {
+std::unique_ptr<BufferAssignment> RunBufferAssignment(HloModule* module) {
+  return BufferAssigner::Run(module, MakeUnique<DependencyHloOrdering>(module),
+                             /*pointer_size=*/sizeof(void*))
+      .ConsumeValueOrDie();
+}
+}
+
+// Tests a computation consisting of a single scalar constant node.
+TEST_F(BufferAssignmentTest, ScalarConstant) {
+  auto builder = HloComputation::Builder(TestName());
+  auto const0 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+  auto module = MakeUnique<HloModule>(TestName());
+  module->AddEntryComputation(builder.Build());
+
+  auto buffers = RunBufferAssignment(module.get());
+  // Check that the constant does not have a buffer assigned.
+  EXPECT_FALSE(buffers->HasTopLevelAllocation(const0));
+}
+
+TEST_F(BufferAssignmentTest, BufferForConst) {
+  // Addition of two vector constants: checks that internal constant nodes have
+  // no buffers assigned, and their consumer has a buffer.
+  auto builder = HloComputation::Builder(TestName());
+  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+      LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
+  auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
+      LiteralUtil::CreateR1<float>({4.1f, 4.2f, 4.3f, 4.4f})));
+  auto add = builder.AddInstruction(
+      HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, const0, const1));
+  auto module = MakeUnique<HloModule>(TestName());
+  module->AddEntryComputation(builder.Build());
+
+  auto buffers = RunBufferAssignment(module.get());
+  // The two constant nodes have no buffers assigned.
+  EXPECT_FALSE(buffers->HasTopLevelAllocation(const0));
+  EXPECT_FALSE(buffers->HasTopLevelAllocation(const1));
+  // The add node has an output buffer.
+  GetAssignedOutputAllocation(*buffers, add);
+}
+
+TEST_F(BufferAssignmentTest, BufferForOutputConst) {
+  // This computation copies a constant to output.
+  auto builder = HloComputation::Builder(TestName());
+  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+      LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
+  auto copy = builder.AddInstruction(
+      HloInstruction::CreateUnary(const0->shape(), HloOpcode::kCopy, const0));
+  auto module = MakeUnique<HloModule>(TestName());
+  module->AddEntryComputation(builder.Build());
+
+  auto buffers = RunBufferAssignment(module.get());
+  // The copy node now has an output buffer.
+  GetAssignedOutputAllocation(*buffers, copy);
+}
+
+TEST_F(BufferAssignmentTest, Basic) {
+  // paramscalar ------- (mul) -- (add) -- (sub)
+  //                     /        /        /
+  // param0[100] -------/        /        /
+  //                             /        /
+  // param1[100] --------------/--------/
+  auto builder = HloComputation::Builder(TestName());
+  auto paramscalar =
+      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, ""));
+  auto param0 = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, f32vec100_, ""));
+  auto param1 = builder.AddInstruction(
+      HloInstruction::CreateParameter(2, f32vec100_, ""));
+  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
+      f32vec100_, HloOpcode::kMultiply, paramscalar, param0));
+  auto add = builder.AddInstruction(
+      HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
+  auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
+      f32vec100_, HloOpcode::kSubtract, add, param1));
+  auto module = MakeUnique<HloModule>(TestName());
+  module->AddEntryComputation(builder.Build());
+
+  auto buffers = RunBufferAssignment(module.get());
+
+  // Distinct input buffers were assigned for parameters.
+  BufferAllocation paramscalar_buffer =
+      GetAssignedInputAllocation(*buffers, paramscalar);
+  BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0);
+  BufferAllocation param1_buffer = GetAssignedInputAllocation(*buffers, param1);
+  EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
+  EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index());
+  EXPECT_NE(param0_buffer.index(), param1_buffer.index());
+
+  // The mul node has a valid buffer assigned, doesn't share with input.
+  const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);
+  EXPECT_NE(mul_buffer.index(), param0_buffer.index());
+
+  // The add node can reuse the mul node's buffer.
+  const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
+  EXPECT_EQ(add_buffer.index(), mul_buffer.index());
+
+  // The sub node has a valid output buffer assigned.
+  GetAssignedOutputAllocation(*buffers, sub);
+}
+
+TEST_F(BufferAssignmentTest, MultipleUsersForNode) {
+  // This is similar to the Basic test, with the difference that (sub) is
+  // another user of (mul)'s result, so (mul)'s buffer cannot be reused for
+  // (add)'s output.
+  //
+  // paramscalar -------\     /-----------\
+  //                     \   /             \
+  // param0[100] ------- (mul) -- (add) -- (sub)
+  //                              /
+  // param1[100] ----------------/
+  //
+  auto builder = HloComputation::Builder(TestName());
+  auto paramscalar =
+      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, ""));
+  auto param0 = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, f32vec100_, ""));
+  auto param1 = builder.AddInstruction(
+      HloInstruction::CreateParameter(2, f32vec100_, ""));
+  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
+      f32vec100_, HloOpcode::kMultiply, paramscalar, param0));
+  auto add = builder.AddInstruction(
+      HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
+  auto sub = builder.AddInstruction(
+      HloInstruction::CreateBinary(f32vec100_, HloOpcode::kSubtract, add, mul));
+  auto module = MakeUnique<HloModule>(TestName());
+  module->AddEntryComputation(builder.Build());
+
+  auto buffers = RunBufferAssignment(module.get());
+
+  // Input buffers were assigned for parameters.
+ BufferAllocation paramscalar_buffer = + GetAssignedInputAllocation(*buffers, paramscalar); + BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0); + BufferAllocation param1_index = GetAssignedInputAllocation(*buffers, param1); + EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index()); + EXPECT_NE(paramscalar_buffer.index(), param1_index.index()); + EXPECT_NE(param0_buffer.index(), param1_index.index()); + + // The mul node had a buffer allocated. + const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul); + + // Now the add node can't reuse the mul node's buffer. + const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add); + EXPECT_NE(add_buffer.index(), mul_buffer.index()); + + // Log size information for inspection. + const std::vector level0 = GetInstructions(sub); + int64 size0 = ValidateBuffers(level0, *buffers); + LOG(INFO) << "LogicalBuffer count " << buffers->Allocations().size() + << " for " << level0.size() << " instructions; " + << "total buffer size " << size0; +} + +TEST_F(BufferAssignmentTest, TrivialMap) { + // This tests a trivial x+1 map as the only operation. + // + // param0[100x10] ---> (map x+1) + // + // Builds the map function. + auto module = MakeUnique(TestName()); + auto map_computation = + module->AddEmbeddedComputation(BuildMapComputationPlus1("f32+1")); + auto inner_last = map_computation->root_instruction(); + + // Creates the main kernel and verifies instruction counts. + auto builder = HloComputation::Builder(TestName()); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32a100x10_, "")); + auto map = builder.AddInstruction( + HloInstruction::CreateMap(f32a100x10_, {param0}, map_computation)); + const std::vector level0 = GetInstructions(map); + EXPECT_EQ(2, level0.size()) << "Invalid main kernel size"; + const std::vector level1 = GetInstructions(inner_last); + EXPECT_EQ(3, level1.size()) << "Invalid nested add+1 size"; + + module->AddEntryComputation(builder.Build()); + + // Assigns buffers and fetches sizes. + auto buffers = RunBufferAssignment(module.get()); + int64 size0 = ValidateBuffers(level0, *buffers); + int64 size1 = ValidateBuffers(level1, *buffers); + + // Both algorithms assign the map's buffer before processing the embedded + // computation, so we can verify that the buffers aren't shared between them + // by checking: + EXPECT_TRUE(BuffersDistinct(level0, level1, *buffers)) + << "Reuse between main kernel and embedded mapping."; + + // An input buffer was assigned for the parameter. + BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0); + + // An output buffer was assigned for the map. + BufferAllocation map_buffer = GetAssignedOutputAllocation(*buffers, map); + EXPECT_NE(param0_buffer.index(), map_buffer.index()); + + // The final computation node of the map is an add of an f32 parm and a + // constant. + EXPECT_EQ(HloOpcode::kAdd, inner_last->opcode()); + const BufferAllocation& inner_add_buffer = + GetTopLevelAllocation(*buffers, inner_last); + EXPECT_NE(inner_add_buffer.index(), map_buffer.index()); + + // Log size information for inspection. + LOG(INFO) << "LogicalBuffer count " << buffers->Allocations().size() + << " for " << level0.size() + level1.size() << " instructions; " + << "total buffer size " << size0 + size1; +} + +TEST_F(BufferAssignmentTest, CannotReuseInputBufferOfReduce) { + // Make sure that the input buffer of a reduce cannot be reused for its + // output. 
(Reuse is not safe in the general case, as it reshapes and some + // out-of-order reductions could overwrite an element before a use.) + // + // param0[100] --- (exp1) --- (exp2) --- (reduce x+1) --- (exp3) + auto module = MakeUnique(TestName()); + auto reduce_computation = + module->AddEmbeddedComputation(BuildMapComputationPlus1("f32+1")); + + auto builder = HloComputation::Builder(TestName()); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32a100x10_, "")); + auto exp1 = builder.AddInstruction( + HloInstruction::CreateUnary(f32a100x10_, HloOpcode::kExp, param0)); + auto exp2 = builder.AddInstruction( + HloInstruction::CreateUnary(f32a100x10_, HloOpcode::kExp, exp1)); + auto const0 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + auto reduce = builder.AddInstruction(HloInstruction::CreateReduce( + /*shape=*/f32vec10_, + /*operand=*/exp2, + /*init_value=*/const0, + /*dimensions_to_reduce=*/{0}, reduce_computation)); + auto exp3 = builder.AddInstruction( + HloInstruction::CreateUnary(f32vec10_, HloOpcode::kExp, reduce)); + + module->AddEntryComputation(builder.Build()); + + auto buffers = RunBufferAssignment(module.get()); + const std::vector instrs = GetInstructions(exp3); + ValidateBuffers(instrs, *buffers); + + const BufferAllocation& exp1_buffer = GetTopLevelAllocation(*buffers, exp1); + const BufferAllocation& exp2_buffer = GetTopLevelAllocation(*buffers, exp2); + const BufferAllocation& reduce_buffer = + GetTopLevelAllocation(*buffers, reduce); + + // The buffer of exp1 is trivially reusable for exp2 - this is just for sanity + // checking. + EXPECT_EQ(exp1_buffer.index(), exp2_buffer.index()); + + // The buffer of exp2 cannot be used for reduce, even though it's the only + // operand. + EXPECT_NE(exp2_buffer.index(), reduce_buffer.index()); +} + +TEST_F(BufferAssignmentTest, ExampleWhile) { + // This tests a While loop example from the ir_semantics document. + // + // condition (s32,f32[4]) -> bool -- see BuildWhileConditionComputation. + // body: (s32,f32[4]) -> (s32,f32[4]) -- see BuildWhileBodyComputation. + // + // const3[s32] -------\ + // const4[f32[4]] --- tuple --- while[condition, body] + // + // Builds the nested condition and body. + auto module = MakeUnique(TestName()); + auto condition_computation = + module->AddEmbeddedComputation(BuildWhileConditionComputation("if<4")); + auto body_computation = + module->AddEmbeddedComputation(BuildWhileBodyComputation("add-update")); + + // Creates the main kernel and verifies instruction counts. 
+ auto builder = HloComputation::Builder(TestName()); + auto const3 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + auto const4 = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); + auto tuple = + builder.AddInstruction(HloInstruction::CreateTuple({const3, const4})); + auto while_op = builder.AddInstruction(HloInstruction::CreateWhile( + t_s32_f32v4_, condition_computation, body_computation, tuple)); + + const std::vector level0 = GetInstructions(while_op); + EXPECT_EQ(4, level0.size()) << "Invalid while kernel size"; + const std::vector levelc = + GetInstructions(condition_computation->root_instruction()); + EXPECT_EQ(4, levelc.size()) << "Invalid nested condition size"; + const std::vector levelb = + GetInstructions(body_computation->root_instruction()); + EXPECT_EQ(8, levelb.size()) << "Invalid nested body size"; + + module->AddEntryComputation(builder.Build()); + + // Assigns buffers and fetches sizes. + auto buffers = RunBufferAssignment(module.get()); + int64 size0 = ValidateBuffers(level0, *buffers); + int64 sizec = ValidateBuffers(levelc, *buffers); + int64 sizeb = ValidateBuffers(levelb, *buffers); + + // BufferAssignment will assign a single allocation for the following + // instructions: while, while.cond.param, while.body.param, while.body.result. + EXPECT_FALSE(BuffersDistinct(level0, levelc, *buffers)) + << "Should be reuse between main kernel and embedded condition."; + EXPECT_FALSE(BuffersDistinct(levelb, levelc, *buffers)) + << "Should be reuse between embedded condition and body."; + // Expect buffer reuse between main kernel and body computation. + EXPECT_FALSE(BuffersDistinct(level0, levelb, *buffers)) + << "Should be reuse between main kernel and embedded body."; + + // The final computation node of the while body is a tuple of s32 and + // f32[4] adds. + HloInstruction* body_root = body_computation->root_instruction(); + EXPECT_EQ(HloOpcode::kTuple, body_root->opcode()); + + // Check that buffer for each subshape of 'while_op' shares allocation with + // corresponding buffer from while body computation at same index. + TF_CHECK_OK(ShapeUtil::ForEachSubshape( + while_op->shape(), + [this, &buffers, while_op, body_root](const Shape& /*subshape*/, + const ShapeIndex& index) { + auto while_op_allocation = GetAllocation(*buffers, while_op, index); + auto body_root_allocation = GetAllocation(*buffers, body_root, index); + EXPECT_EQ(while_op_allocation.index(), body_root_allocation.index()); + return Status::OK(); + })); + + // Log size information for inspection. 
+ LOG(INFO) << "LogicalBuffer count " << buffers->Allocations().size() + << " for " << level0.size() + levelc.size() + levelb.size() + << " instructions; total buffer size " << size0 + sizec + sizeb; +} + +TEST_F(BufferAssignmentTest, UnaryOpReuseChain) { + // param0[100] ---> (exp) ---> (tanh) ---> (exp) ---> (neg) + auto builder = HloComputation::Builder(TestName()); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32vec100_, "")); + auto exp1 = builder.AddInstruction( + HloInstruction::CreateUnary(f32vec100_, HloOpcode::kExp, param0)); + auto tanh = builder.AddInstruction( + HloInstruction::CreateUnary(f32vec100_, HloOpcode::kTanh, exp1)); + auto exp2 = builder.AddInstruction( + HloInstruction::CreateUnary(f32vec100_, HloOpcode::kExp, tanh)); + auto neg = builder.AddInstruction( + HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, exp2)); + + auto module = MakeUnique(TestName()); + module->AddEntryComputation(builder.Build()); + auto assignment = RunBufferAssignment(module.get()); + + // tanh and exp2 can reuse exp1's buffer + EXPECT_TRUE(assignment->HasTopLevelAllocation(exp1)); + auto& buffer_for_exp1 = GetTopLevelAllocation(*assignment, exp1); + EXPECT_EQ(buffer_for_exp1, GetTopLevelAllocation(*assignment, tanh)); + EXPECT_EQ(buffer_for_exp1, GetTopLevelAllocation(*assignment, exp2)); + EXPECT_EQ(buffer_for_exp1, GetTopLevelAllocation(*assignment, neg)); +} + +TEST_F(BufferAssignmentTest, ReuseNonOperandBuffer) { + // This computation is a chain of operations which decreases in buffer size + // (via slice) then increases in size (via broadcast): + // + // param ---> (negate) ---> (slice) ---> (broadcast) + // + // The negate should share a buffer with broadcast. + auto builder = HloComputation::Builder(TestName()); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32vec100_, "param0")); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0)); + auto slice = builder.AddInstruction( + HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10})); + auto broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(f32a100x10_, slice, {1})); + + auto module = MakeUnique(TestName()); + module->AddEntryComputation(builder.Build()); + auto assignment = RunBufferAssignment(module.get()); + + // negate and broadcast should share a buffer. + EXPECT_TRUE(assignment->HasTopLevelAllocation(broadcast)); + auto& buffer_for_bcast = GetTopLevelAllocation(*assignment, broadcast); + EXPECT_EQ(buffer_for_bcast, GetTopLevelAllocation(*assignment, negate)); + + // Slice should have its own buffer. + EXPECT_NE(buffer_for_bcast, GetTopLevelAllocation(*assignment, slice)); +} + +TEST_F(BufferAssignmentTest, NoReuseLiveBuffer) { + // This computation is identical to that in ReuseNonOperandBuffer, but the + // negate value is live until the end of the computation (due to it being an + // operand of the output tuple) preventing reuse. + // + // param ---> (negate) ---> (slice) ---> (broadcast)-> (tuple) + // \-----------------------------------/ + // + // The negate should not share a buffer with broadcast. 
+ auto builder = HloComputation::Builder(TestName()); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32vec100_, "param0")); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0)); + auto slice = builder.AddInstruction( + HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10})); + auto broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(f32a100x10_, slice, {1})); + builder.AddInstruction(HloInstruction::CreateTuple({negate, broadcast})); + + auto module = MakeUnique(TestName()); + module->AddEntryComputation(builder.Build()); + auto assignment = RunBufferAssignment(module.get()); + + // The instructions should not share buffers. + EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast), + GetTopLevelAllocation(*assignment, negate)); + EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast), + GetTopLevelAllocation(*assignment, slice)); + EXPECT_NE(GetTopLevelAllocation(*assignment, negate), + GetTopLevelAllocation(*assignment, slice)); +} + +TEST_F(BufferAssignmentTest, NoReuseAliasedBuffer) { + // This computation is identical to that in ReuseNonOperandBuffer, but the + // negate value is placed into a tuple which lives to the end of the + // computation. This extends the live range of negate's buffer preventing + // reuse due to buffer aliasing. + // + // param ---> (negate) ---> (tuple) -> (slice) ---> (broadcast)-> (tuple) + // \-----------------------------------/ + // + // The negate should not share a buffer with broadcast. + auto builder = HloComputation::Builder(TestName()); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32vec100_, "param0")); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0)); + auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({negate})); + auto tuple_element = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(f32vec100_, tuple, 0)); + auto slice = builder.AddInstruction( + HloInstruction::CreateSlice(f32vec10_, tuple_element, {0}, {10})); + auto broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(f32a100x10_, slice, {1})); + builder.AddInstruction(HloInstruction::CreateTuple({tuple, broadcast})); + + auto module = MakeUnique(TestName()); + module->AddEntryComputation(builder.Build()); + auto assignment = RunBufferAssignment(module.get()); + + // The instructions should not share buffers. + EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast), + GetTopLevelAllocation(*assignment, negate)); + EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast), + GetTopLevelAllocation(*assignment, slice)); + EXPECT_NE(GetTopLevelAllocation(*assignment, negate), + GetTopLevelAllocation(*assignment, slice)); +} + +TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBuffer) { + // This computation is very similar to ReuseNonOperandBuffer except the + // broadcast has a smaller output than the negate. This should block reuse of + // negate's buffer by broadcast because the output buffer(s) of a computation + // should be exactly sized for the value. + // + // param ---> (negate) ---> (slice) ---> (broadcast) + // + // The negate should *not* share a buffer with broadcast. + auto builder = HloComputation::Builder(TestName()); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32vec100_, "param0")); + // Negate output is 100 elements. 
+ auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0)); + auto slice = builder.AddInstruction( + HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10})); + // Broadcast output is 40 elements. + auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShape(F32, {10, 4}), slice, {0})); + + auto module = MakeUnique(TestName()); + module->AddEntryComputation(builder.Build()); + auto assignment = RunBufferAssignment(module.get()); + + // The instructions should not share buffers. + EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast), + GetTopLevelAllocation(*assignment, negate)); + EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast), + GetTopLevelAllocation(*assignment, slice)); + EXPECT_NE(GetTopLevelAllocation(*assignment, negate), + GetTopLevelAllocation(*assignment, slice)); +} + +TEST_F(BufferAssignmentTest, ReuseOutputBufferIfExactlySized) { + // This is identical to DoNotReuseOversizedOutputBuffer except the broadcast + // output is exactly the same size as the negate (rather than being + // smaller). This enables reuse of negate's buffer by the broadcast because + // the output buffer will be sized exactly to its value. + // + // param ---> (negate) ---> (slice) ---> (broadcast) + // + // The negate should *not* share a buffer with broadcast. + auto builder = HloComputation::Builder(TestName()); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32vec100_, "param0")); + // Negate output is 100 elements. + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0)); + auto slice = builder.AddInstruction( + HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10})); + // Broadcast output is 40 elements. + auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShape(F32, {10, 10}), slice, {0})); + + auto module = MakeUnique(TestName()); + module->AddEntryComputation(builder.Build()); + auto assignment = RunBufferAssignment(module.get()); + + // negate and broadcast should share a buffer. + EXPECT_TRUE(assignment->HasTopLevelAllocation(broadcast)); + auto& buffer_for_bcast = GetTopLevelAllocation(*assignment, broadcast); + EXPECT_EQ(buffer_for_bcast, GetTopLevelAllocation(*assignment, negate)); + + // Slice should have its own buffer. + EXPECT_NE(buffer_for_bcast, GetTopLevelAllocation(*assignment, slice)); +} + +TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBufferInTuple) { + // This computation is very similar to ReuseNonOperandBuffer except the + // broadcast has a smaller output than the negate, and the broadcast is + // contained in the computation output as a tuple element. This should block + // reuse of the negate's buffer by the broadcast because the output buffer(s) + // of a computation should be exactly sized for the value. This includes those + // buffers aliased in the output (eg, contained as tuple elements). + // + // param ---> (negate) ---> (slice) ---> (broadcast) --> (tuple) + // + // The negate should *not* share a buffer with broadcast. + auto builder = HloComputation::Builder(TestName()); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32vec100_, "param0")); + // Negate output is 100 elements. 
+ auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0)); + auto slice = builder.AddInstruction( + HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10})); + // Broadcast output is 40 elements. + auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShape(F32, {10, 4}), slice, {0})); + builder.AddInstruction(HloInstruction::CreateTuple({broadcast})); + + auto module = MakeUnique(TestName()); + module->AddEntryComputation(builder.Build()); + auto assignment = RunBufferAssignment(module.get()); + + // The instructions should not share buffers. + EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast), + GetTopLevelAllocation(*assignment, negate)); + EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast), + GetTopLevelAllocation(*assignment, slice)); + EXPECT_NE(GetTopLevelAllocation(*assignment, negate), + GetTopLevelAllocation(*assignment, slice)); +} + +TEST_F(BufferAssignmentTest, EmbeddedComputationBuffers) { + // Verify that buffers for embedded computations are properly marked as + // thread-local and that embedded parameters are not marked as + // is_entry_computation_parameter. + auto module = MakeUnique(TestName()); + auto vec_shape = ShapeUtil::MakeShape(F32, {42}); + auto scalar_shape = ShapeUtil::MakeShape(F32, {}); + + // Create a scalar computation to use in a map. + auto map_builder = HloComputation::Builder(TestName() + "_map"); + auto map_param = map_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "map_param")); + auto map_root = map_builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, map_param)); + auto map_computation = module->AddEmbeddedComputation(map_builder.Build()); + + // Create a vector computation to use in a kCall. + auto call_builder = HloComputation::Builder(TestName() + "_call"); + auto call_param = call_builder.AddInstruction( + HloInstruction::CreateParameter(0, vec_shape, "vec_param")); + auto call_root = call_builder.AddInstruction( + HloInstruction::CreateUnary(vec_shape, HloOpcode::kExp, call_param)); + auto call_computation = module->AddEmbeddedComputation(call_builder.Build()); + + // Create entry computation which kCalls call_computation and then calls map + // with map_computation on the result. + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, vec_shape, "param")); + auto call = builder.AddInstruction( + HloInstruction::CreateCall(vec_shape, {param}, call_computation)); + auto map = builder.AddInstruction( + HloInstruction::CreateMap(vec_shape, {call}, map_computation)); + module->AddEntryComputation(builder.Build()); + + auto assignment = RunBufferAssignment(module.get()); + + // Allocations for the map computation should be thread-local and not + // live-out. + auto& map_param_alloc = GetTopLevelAllocation(*assignment, map_param); + EXPECT_FALSE(map_param_alloc.is_entry_computation_parameter()); + EXPECT_FALSE(map_param_alloc.maybe_live_out()); + EXPECT_TRUE(map_param_alloc.is_thread_local()); + + auto& map_root_alloc = GetTopLevelAllocation(*assignment, map_root); + EXPECT_FALSE(map_root_alloc.is_entry_computation_parameter()); + EXPECT_FALSE(map_root_alloc.maybe_live_out()); + EXPECT_TRUE(map_root_alloc.is_thread_local()); + + // Allocations for the call computation should not be thread-local and not + // live-out. 
+ auto& call_param_alloc = GetTopLevelAllocation(*assignment, call_param); + EXPECT_FALSE(call_param_alloc.is_entry_computation_parameter()); + EXPECT_FALSE(call_param_alloc.maybe_live_out()); + EXPECT_FALSE(call_param_alloc.is_thread_local()); + + auto& call_root_alloc = GetTopLevelAllocation(*assignment, call_root); + EXPECT_FALSE(call_root_alloc.is_entry_computation_parameter()); + EXPECT_FALSE(call_root_alloc.maybe_live_out()); + EXPECT_FALSE(call_root_alloc.is_thread_local()); + + // Entry computation allocations can be marked liveout and + // is_entry_computation_parameter. + auto& param_alloc = GetTopLevelAllocation(*assignment, param); + EXPECT_TRUE(param_alloc.is_entry_computation_parameter()); + EXPECT_FALSE(param_alloc.maybe_live_out()); + EXPECT_FALSE(param_alloc.is_thread_local()); + + auto& map_alloc = GetTopLevelAllocation(*assignment, map); + EXPECT_FALSE(map_alloc.is_entry_computation_parameter()); + EXPECT_TRUE(map_alloc.maybe_live_out()); + EXPECT_FALSE(map_alloc.is_thread_local()); +} + +TEST_F(BufferAssignmentTest, TupleParameterAsOutput) { + // Test a computation that returns a tuple parameter. + auto builder = HloComputation::Builder(TestName()); + auto tuple_param = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}), + ShapeUtil::MakeShape(F32, {}), + ShapeUtil::MakeShape(S32, {42})}), + "param0")); + + auto module = MakeUnique(TestName()); + module->AddEntryComputation(builder.Build()); + auto assignment = RunBufferAssignment(module.get()); + + // There should be four allocations: one for vector of pointers, and one for + // each tuple element. + EXPECT_EQ(4, assignment->Allocations().size()); + + // Verify each buffer allocation is marked as an entry computation parameter + // and is liveout. + TF_CHECK_OK(ShapeUtil::ForEachSubshape( + tuple_param->shape(), + [this, &assignment, tuple_param](const Shape& /*subshape*/, + const ShapeIndex& index) { + auto allocation = GetAllocation(*assignment, tuple_param, index); + EXPECT_TRUE(allocation.is_entry_computation_parameter()); + EXPECT_EQ(0, allocation.parameter_number()); + EXPECT_TRUE(allocation.maybe_live_out()); + return Status::OK(); + })); +} + +TEST_F(BufferAssignmentTest, ElementOfNestedTupleParameterAsOutput) { + // Test a computation which returns a GetElementTuple of a nested tuple + // parameter. + auto builder = HloComputation::Builder(TestName()); + auto tuple_param = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}), + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {42}), + ShapeUtil::MakeShape(S32, {101})})}), + "param0")); + auto tuple_element = + builder.AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetSubshape(tuple_param->shape(), {1}), tuple_param, 1)); + + auto module = MakeUnique(TestName()); + module->AddEntryComputation(builder.Build()); + auto assignment = RunBufferAssignment(module.get()); + + // Only some of the elements of the input param are liveout. + EXPECT_FALSE( + GetAllocation(*assignment, tuple_param, /*index=*/{}).maybe_live_out()); + // Tuple element at index={1} is live out because GetTupleElement({1}) + // forwards a pointer to this allocation (instead of defining its own buffer). 
+ EXPECT_TRUE( + GetAllocation(*assignment, tuple_param, /*index=*/{1}).maybe_live_out()); + EXPECT_TRUE(GetAllocation(*assignment, tuple_param, /*index=*/{1, 0}) + .maybe_live_out()); + EXPECT_TRUE(GetAllocation(*assignment, tuple_param, /*index=*/{1, 1}) + .maybe_live_out()); + + // The GetTupleElement output is liveout. + EXPECT_TRUE( + GetTopLevelAllocation(*assignment, tuple_element).maybe_live_out()); + + // Verify that the GetTupleElement allocations of its elements match the + // corresponding tuple parameter allocations because they alias. + EXPECT_EQ(GetAllocation(*assignment, tuple_param, /*index=*/{1, 0}), + GetAllocation(*assignment, tuple_element, /*index=*/{0})); + EXPECT_EQ(GetAllocation(*assignment, tuple_param, /*index=*/{1, 1}), + GetAllocation(*assignment, tuple_element, /*index=*/{1})); + + // GetTupleElement forwards a pointer to its underlying buffer, so verify + // that it has the same allocation than the corresponding parameter element. + EXPECT_EQ(GetAllocation(*assignment, tuple_param, /*index=*/{1}), + GetTopLevelAllocation(*assignment, tuple_element)); +} + +// TODO(b/32248867): Enable when buffer assignment gives allocations to +// constants. +TEST_F(BufferAssignmentTest, DISABLED_TupleConstantAsOutput) { + // Test that a tuple constant which is forwarded to the computation output is + // properly handled. + auto builder = HloComputation::Builder(TestName()); + builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::MakeTuple({LiteralUtil::CreateR0(0).get(), + LiteralUtil::CreateR0(1).get()}))); + + auto module = MakeUnique(TestName()); + module->AddEntryComputation(builder.Build()); + auto assignment = RunBufferAssignment(module.get()); + + EXPECT_EQ(3, assignment->Allocations().size()); +} + +TEST_F(BufferAssignmentTest, TupleCustomCallAsOutput) { + // Test a computation which returns a tuple custom call value. + auto builder = HloComputation::Builder(TestName()); + auto custom_call = builder.AddInstruction(HloInstruction::CreateCustomCall( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}), + ShapeUtil::MakeShape(S32, {101})}), + /*operands=*/{}, /*custom_call_target=*/"foo_function")); + auto module = MakeUnique(TestName()); + module->AddEntryComputation(builder.Build()); + auto assignment = RunBufferAssignment(module.get()); + + EXPECT_EQ(3, assignment->Allocations().size()); + EXPECT_TRUE( + GetAllocation(*assignment, custom_call, /*index=*/{}).maybe_live_out()); + EXPECT_TRUE( + GetAllocation(*assignment, custom_call, /*index=*/{0}).maybe_live_out()); + EXPECT_TRUE( + GetAllocation(*assignment, custom_call, /*index=*/{1}).maybe_live_out()); +} + +TEST_F(BufferAssignmentTest, BitcastAsOutput) { + // Test a computation which returns a bitcast value. + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {42}), "param")); + auto bitcast = builder.AddInstruction( + HloInstruction::CreateUnary(param->shape(), HloOpcode::kBitcast, param)); + + auto module = MakeUnique(TestName()); + module->AddEntryComputation(builder.Build()); + auto assignment = RunBufferAssignment(module.get()); + + // Bitcast should get the same allocation as the param. + EXPECT_EQ(1, assignment->Allocations().size()); + EXPECT_EQ(GetTopLevelAllocation(*assignment, param), + GetTopLevelAllocation(*assignment, bitcast)); +} + +TEST_F(BufferAssignmentTest, AmbiguousBufferAsOutput) { + // Test a computation with an output that has an ambiguous points-to set. 
This + // is constructed using a select among tuple shapes. + auto builder = HloComputation::Builder(TestName()); + auto tuple_shape = + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(PRED, {1, 2, 3, 4})}); + + auto tuple_param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "param0")); + auto tuple_param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, tuple_shape, "param1")); + auto pred_param = builder.AddInstruction(HloInstruction::CreateParameter( + 2, ShapeUtil::MakeShape(PRED, {}), "param1")); + auto select = builder.AddInstruction(HloInstruction::CreateTernary( + tuple_shape, HloOpcode::kSelect, pred_param, tuple_param0, tuple_param1)); + + auto module = MakeUnique(TestName()); + module->AddEntryComputation(builder.Build()); + auto assignment = RunBufferAssignment(module.get()); + + // Select shallow copies one of its operands so it defines its own top-level + // buffer and receives its own allocation. + auto select_alloc = GetTopLevelAllocation(*assignment, select); + EXPECT_EQ(1, select_alloc.assigned_buffers().size()); + EXPECT_EQ(select, select_alloc.assigned_buffers()[0]->instruction()); + + // The buffer for the tuple element of the select is forwarded from one its + // operands which cannot be determined statically. Therefore its allocation + // should include the allocations of both of the elements in the parameters. + auto element_allocations = assignment->GetAllocations(select, /*index=*/{0}); + EXPECT_EQ(2, element_allocations.size()); + EXPECT_MATCH(testing::SetToVec(element_allocations), + testing::UnorderedMatcher( + *assignment->GetUniqueAllocation(tuple_param0, /*index=*/{0}) + .ConsumeValueOrDie(), + *assignment->GetUniqueAllocation(tuple_param1, /*index=*/{0}) + .ConsumeValueOrDie())); +} + +} // namespace + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/buffer_liveness.cc b/tensorflow/compiler/xla/service/buffer_liveness.cc new file mode 100644 index 0000000000..d4faca7cd8 --- /dev/null +++ b/tensorflow/compiler/xla/service/buffer_liveness.cc @@ -0,0 +1,259 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Defines the data returned by the XLA buffer assignment packages. 
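+//
+// Usage sketch (illustrative only) showing how the classes declared in
+// buffer_liveness.h fit together; `module` stands for an HloModule* and
+// `a`, `b` for LogicalBuffers obtained from the points-to analysis:
+//
+//   auto liveness =
+//       BufferLiveness::Run(module, MakeUnique<DependencyHloOrdering>(module))
+//           .ConsumeValueOrDie();
+//   bool may_interfere = liveness->MayInterfere(a, b);
+//   bool maybe_live_out = liveness->MaybeLiveOut(a);
+//
+// A SequentialHloOrdering built from per-computation instruction sequences may
+// be passed instead to model single-threaded execution, which typically
+// permits more buffer reuse than the dependency-based ordering.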
+ +#include "tensorflow/compiler/xla/service/buffer_liveness.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/logical_buffer.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +PredecessorHloOrdering::PredecessorHloOrdering(const HloModule* module) + : module_(module) {} + +bool PredecessorHloOrdering::ExecutesBefore(const HloInstruction* a, + const HloInstruction* b) const { + // Instructions in different computations are unordered. + if (a->parent() != b->parent()) { + return false; + } + // 'a' executes before 'b' if 'a' is in the strict predecessor set of 'b'. + return strict_predecessors_.at(b->parent())->IsReachable(b, a); +} + +string PredecessorHloOrdering::ToStringHelper(const string& name) const { + std::vector pieces; + pieces.push_back(name); + for (auto& computation : module_->computations()) { + pieces.push_back(tensorflow::strings::Printf("computation %s:", + computation->name().c_str())); + const auto all = computation->MakeInstructionPostOrder(); + for (auto instruction : all) { + pieces.push_back(tensorflow::strings::Printf( + " %s strict predecessors:", instruction->name().c_str())); + for (auto predecessor : all) { + if (strict_predecessors_.at(computation.get()) + ->IsReachable(instruction, predecessor)) { + pieces.push_back( + tensorflow::strings::Printf(" %s", predecessor->name().c_str())); + } + } + } + } + return tensorflow::str_util::Join(pieces, "\n"); +} + +DependencyHloOrdering::DependencyHloOrdering(const HloModule* module) + : PredecessorHloOrdering(module) { + // Compute predecessor relationships between all instructions to determine + // ordering based on dependencies. ExecutesBefore will return true iff there + // exists a path in the HLO computation graph from 'a' to 'b'. + for (auto& computation : module->computations()) { + strict_predecessors_.emplace(computation.get(), + computation->ComputeTransitiveOperands()); + } +} + +string DependencyHloOrdering::ToString() const { + return ToStringHelper("DependencyHloOrdering"); +} + +SequentialHloOrdering::SequentialHloOrdering( + const HloModule* module, const HloModuleSequence& module_sequence) + : module_(module) { + // Create a map from instruction to its order position. + for (auto computation_order : module_sequence) { + const std::vector& order = computation_order.second; + for (int i = 0; i < order.size(); ++i) { + DCHECK_EQ(0, order_position_.count(order[i])); + order_position_.emplace(order[i], i); + } + } +} + +bool SequentialHloOrdering::ExecutesBefore(const HloInstruction* a, + const HloInstruction* b) const { + // Instructions in different computations are unordered. + if (a->parent() != b->parent()) { + return false; + } + // If either instruction is not in the order, then 'a' and 'b' are unordered. 
+ if (order_position_.count(a) == 0 || order_position_.count(b) == 0) { + return false; + } + return order_position_.at(a) < order_position_.at(b); +} + +string SequentialHloOrdering::ToString() const { + std::vector pieces; + pieces.push_back("SequentialHloOrdering"); + for (auto& computation : module_->computations()) { + pieces.push_back(tensorflow::strings::Printf("computation %s order:", + computation->name().c_str())); + // Gather all instructions in the module sequence for this computation and + // sort them by their position. + std::vector instructions; + for (auto& instruction_position : order_position_) { + const HloInstruction* instruction = instruction_position.first; + if (instruction->parent() == computation.get()) { + instructions.push_back(instruction); + } + } + std::sort(instructions.begin(), instructions.end(), + [this](const HloInstruction* a, const HloInstruction* b) { + return order_position_.at(a) < order_position_.at(b); + }); + for (auto instruction : instructions) { + pieces.push_back( + tensorflow::strings::Printf(" %s", instruction->name().c_str())); + } + } + return tensorflow::str_util::Join(pieces, "\n"); +} + +/* static */ +StatusOr> BufferLiveness::Run( + const HloModule* module, std::unique_ptr hlo_ordering) { + std::unique_ptr liveness( + new BufferLiveness(module, std::move(hlo_ordering))); + TF_RETURN_IF_ERROR(liveness->Analyze()); + return std::move(liveness); +} + +tensorflow::Status BufferLiveness::Analyze() { + TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module_)); + for (auto& computation : module_->computations()) { + // Gather all instructions whose buffers might alias other instructions into + // the set aliased_buffers_. This includes those contained as a tuple + // element in other instruction's output. + for (const auto& instruction : computation->instructions()) { + for (const LogicalBuffer* aliased_buffer : + points_to_analysis_->GetPointsToSet(instruction.get()) + .CreateFlattenedSet()) { + if (aliased_buffer->instruction() != instruction.get()) { + aliased_buffers_.insert(aliased_buffer); + } + } + } + + if (computation.get() == module_->entry_computation()) { + for (const LogicalBuffer* live_out_buffer : + points_to_analysis_->GetPointsToSet(computation->root_instruction()) + .CreateFlattenedSet()) { + maybe_live_out_buffers_.insert(live_out_buffer); + } + } + } + + XLA_VLOG_LINES(3, ToString()); + return tensorflow::Status::OK(); +} + +string BufferLiveness::ToString() const { + std::vector pieces; + pieces.push_back(tensorflow::strings::Printf("BufferLiveness(module=%s):", + module_->name().c_str())); + pieces.push_back("HloOrdering:"); + pieces.push_back(hlo_ordering_->ToString()); + pieces.push_back(tensorflow::strings::Printf("Aliased buffers:")); + for (const LogicalBuffer* buffer : aliased_buffers_) { + pieces.push_back( + tensorflow::strings::Printf(" %s", buffer->ToString().c_str())); + } + pieces.push_back(tensorflow::strings::Printf("Live out buffers:")); + for (const LogicalBuffer* buffer : maybe_live_out_buffers_) { + pieces.push_back( + tensorflow::strings::Printf(" %s", buffer->ToString().c_str())); + } + return tensorflow::str_util::Join(pieces, "\n"); +} + +// Returns false if 'user' cannot possibly use the buffer at 'index' in +// 'operand'. Returns true otherwise. +// Precondition: 'operand' is an operand of 'user'. 
+bool MayUseBufferInOperand(HloInstruction* operand, const ShapeIndex& index, + HloInstruction* user) { + if (user->opcode() == HloOpcode::kGetTupleElement && !index.empty()) { + // GetTupleElement instructions only access the top-level buffer of their + // operand. + return false; + } + return true; +} + +bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a, + const LogicalBuffer& b) const { + TF_CHECK_OK(points_to_analysis_->VerifyBuffer(a)); + TF_CHECK_OK(points_to_analysis_->VerifyBuffer(b)); + + if (!hlo_ordering_->ExecutesBefore(a.instruction(), b.instruction())) { + return false; + } + + // Every user of 'a' must be a predecessor of 'b' or 'b' itself. + for (const BufferAlias& alias : points_to_analysis_->GetBufferAliases(a)) { + for (auto user : alias.instruction()->users()) { + if (!MayUseBufferInOperand(alias.instruction(), alias.index(), user)) { + continue; + } + if (user != b.instruction() && + !hlo_ordering_->ExecutesBefore(user, b.instruction())) { + return false; + } + } + } + + // If 'b' is a user of 'a' then the buffers interfere if b is not an + // elementwise operation emitting the same shape/layout as 'a'. + for (const BufferAlias& alias : points_to_analysis_->GetBufferAliases(a)) { + if (alias.instruction()->users().count(b.instruction()) > 0 && + (!ShapeUtil::Equal(alias.instruction()->shape(), + b.instruction()->shape()) || + !b.instruction()->IsElementwise())) { + return false; + } + } + return true; +} + +bool BufferLiveness::MayInterfere(const LogicalBuffer& a, + const LogicalBuffer& b) const { + return (!live_range_strictly_before(a, b) && + !live_range_strictly_before(b, a)); +} + +bool BufferLiveness::MaybeLiveOut(const LogicalBuffer& buffer) const { + // Verify that a buffer is actually defined at the given instruction/index + // (eg, its not an alias of another buffer such as occurs with a bitcast). + TF_CHECK_OK(points_to_analysis_->VerifyBuffer(buffer)); + return maybe_live_out_buffers_.count(&buffer); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/buffer_liveness.h b/tensorflow/compiler/xla/service/buffer_liveness.h new file mode 100644 index 0000000000..964f558c8c --- /dev/null +++ b/tensorflow/compiler/xla/service/buffer_liveness.h @@ -0,0 +1,215 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_LIVENESS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_LIVENESS_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/flatset.h" + +namespace xla { + +// Abstract base class for describing a partial ordering of HLO +// instructions. Used to determine live range overlap of HLO instruction output +// buffers. +class HloOrdering { + public: + HloOrdering() = default; + virtual ~HloOrdering() = default; + + // Returns true if instruction 'a' executes before instruction 'b'. This is + // not reflexive, that is, an instruction does not execute before itself. + virtual bool ExecutesBefore(const HloInstruction* a, + const HloInstruction* b) const = 0; + virtual string ToString() const = 0; +}; + +// Base class for partial orderings implemented by a map of strict predecessors +// for each instruction. Subclasses should fill in strict_predecessors_. +class PredecessorHloOrdering : public HloOrdering { + public: + ~PredecessorHloOrdering() override = default; + + // Returns true if instruction 'a' executes before instruction 'b'. + // Instructions in different computations are not ordered. + bool ExecutesBefore(const HloInstruction* a, + const HloInstruction* b) const override; + + protected: + explicit PredecessorHloOrdering(const HloModule* module); + string ToStringHelper(const string& name) const; + + const HloModule* module_; + + // For each each computation in the module, this is the set of the + // instruction's strict predecessors. An instruction is not an element of its + // own strict predecessor set. + // + // Subclasses should fill this in to define the desired ordering. + tensorflow::gtl::FlatMap> + strict_predecessors_; +}; + +// An HLO ordering based on data dependencies in the HLO graph. In this partial +// order, instruction A executes before instruction B only if there is a path +// from A to B in the HLO graph. For example, given the following graph: +// +// param +// / \ +// negate exp +// \ / +// add +// +// DependencyHloOrdering gives the following executes-before relations: +// param executes before negate, exp, and add +// negate executes before add +// exp executes before add +// add executes before nothing +// negate and exp are not ordered because the dependencies allow either to +// execute before the other (or in parallel). DependencyHloOrdering ordering +// allows maximum parallelism and enables any execution order which satisfies +// data dependencies. This requires pessimistic assumptions about buffer live +// ranges and can result in more memory used than more constrained orderings. +class DependencyHloOrdering : public PredecessorHloOrdering { + public: + explicit DependencyHloOrdering(const HloModule* module); + ~DependencyHloOrdering() override = default; + + string ToString() const override; +}; + +// An HLO ordering based on a total order of instructions in each computation. +// The computation total order is a sequencing of all of its instructions in +// the computation (eg, {inst0, inst1, inst2,...}) as in single-threaded +// execution. 
For example, given the following HLO graph: +// +// param +// / \ +// negate exp +// \ / +// add +// +// and the following sequence: +// +// {param, negate, exp, add} +// +// SequentialHloOrdering gives the following executes-before relations: +// param executes before negate, exp, and add +// negate executes before exp and add +// exp executes before add +// add executes before nothing +// This is more constrained than DependencyHloOrdering in this example because +// negate and exp are ordered (negate before exp). This enables param to share +// the same buffer as exp (param buffer is dead after exp). Generally, this +// ordering enables more buffer sharing (reduced memory usage) because buffer +// interference is reduced relative to DependencyHloOrdering. +class SequentialHloOrdering : public HloOrdering { + public: + // A sequence of instructions for each computation in the module. + using HloModuleSequence = + tensorflow::gtl::FlatMap>; + + SequentialHloOrdering(const HloModule* module, + const HloModuleSequence& module_sequence); + ~SequentialHloOrdering() override = default; + + // Instruction 'a' executes before 'b' if 'a' appears before 'b' in the + // instruction sequence for the computation. Instructions in different + // computations are unordered. + bool ExecutesBefore(const HloInstruction* a, + const HloInstruction* b) const override; + string ToString() const override; + + protected: + const HloModule* module_; + + // The position of every instruction in the HLO module in its respective + // computation sequence (a value of zero indicates the instruction is first in + // the sequence, etc). Instructions from all computations are contained in + // this map so more than one instruction may have the same position + // value. This is not a problem because ExecutesBefore also verifies + // instructions are in the same computation. + tensorflow::gtl::FlatMap order_position_; +}; + +// Class which computes liveness of the output buffers of HLOs and their +// interference. +class BufferLiveness { + public: + // Constructs a buffer liveness object for the given module assuming the given + // HLO instruction ordering. + static StatusOr> Run( + const HloModule* module, std::unique_ptr hlo_ordering); + + // Returns true if the live range of the buffer containing the output of 'a' + // may overlap with the live range of the buffer of 'b'. If instruction 'a' + // interferes with instruction 'b' then they cannot share the same buffer. + bool MayInterfere(const LogicalBuffer& a, const LogicalBuffer& b) const; + + // Returns true if the buffer for the given instruction may be live out of the + // module. That is, the instruction's buffer may be included in the output of + // the entry computation. + bool MaybeLiveOut(const LogicalBuffer& buffer) const; + + // Returns the underlying points-to analysis used for this liveness analysis. + const TuplePointsToAnalysis& points_to_analysis() const { + return *points_to_analysis_; + } + + string ToString() const; + + private: + explicit BufferLiveness(const HloModule* module, + std::unique_ptr hlo_ordering) + : module_(module), hlo_ordering_(std::move(hlo_ordering)) {} + + // Perform buffer liveness analysis. This method must be called prior to + // MayInterfere or MaybeLiveOut. + tensorflow::Status Analyze(); + + // Returns true if the live range of the buffer of 'a' is strictly before the + // live range of the buffer of 'b' (they do not overlap). 
+ bool live_range_strictly_before(const LogicalBuffer& a, + const LogicalBuffer& b) const; + + const HloModule* module_; + std::unique_ptr hlo_ordering_; + + // Set of LogicalBuffers which are aliased in the output of other + // instructions. For example, a LogicalBuffer which is inserted into a tuple + // is considered to be aliased and will be in this set. + tensorflow::gtl::FlatSet aliased_buffers_; + + // LogicalBuffers that may be live out of the entry computation. + tensorflow::gtl::FlatSet maybe_live_out_buffers_; + + std::unique_ptr points_to_analysis_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_LIVENESS_H_ diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc new file mode 100644 index 0000000000..1ca5768dbe --- /dev/null +++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc @@ -0,0 +1,487 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/buffer_liveness.h" + +#include +#include + +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace { + +class BufferLivenessTest : public HloTestBase { + protected: + // Returns the LogicalBuffer defined at the given instruction and + // index. CHECKs if no buffer is defined at that point. + const LogicalBuffer& GetBuffer(const BufferLiveness& liveness, + const HloInstruction* instruction, + const ShapeIndex& index) { + const std::vector& pointed_to = + liveness.points_to_analysis() + .GetPointsToSet(instruction) + .element(index); + CHECK_EQ(1, pointed_to.size()); + CHECK_EQ(instruction, pointed_to[0]->instruction()); + CHECK(index == pointed_to[0]->index()); + return *pointed_to[0]; + } + + // Returns true if the top-level buffers for instructions 'a' and 'b' may + // interfere. Precondition: 'a' and 'b' are array-shaped. + bool InstructionsMayInterfere(const BufferLiveness& liveness, + HloInstruction* a, HloInstruction* b) { + EXPECT_FALSE(ShapeUtil::IsTuple(a->shape())); + EXPECT_FALSE(ShapeUtil::IsTuple(b->shape())); + return liveness.MayInterfere( + GetBuffer(liveness, /*instruction=*/a, /*index=*/{}), + GetBuffer(liveness, /*instruction=*/b, /*index=*/{})); + } + + // Returns true if the tuple elements at 'index' for instructions 'a' and 'b' + // may interfere. Precondition: 'a' and 'b' are tuple-shaped, with equal + // tuple element sub-shapes. 
+ bool TupleElementsMayInterfere(const BufferLiveness& liveness, + HloInstruction* a, HloInstruction* b, + const ShapeIndex& index) { + // Check that top-level shapes are tuple and tuple element shapes are equal. + EXPECT_TRUE(ShapeUtil::IsTuple(a->shape())); + EXPECT_TRUE(ShapeUtil::IsTuple(b->shape())); + EXPECT_TRUE( + ShapeUtil::Compatible(ShapeUtil::GetSubshape(a->shape(), index), + ShapeUtil::GetSubshape(b->shape(), index))); + // Lookup PointsTo set for instructions 'a' and 'b'. + auto& points_to_analysis = liveness.points_to_analysis(); + const std::vector& points_to_a = + points_to_analysis.GetPointsToSet(a).element(index); + const std::vector& points_to_b = + points_to_analysis.GetPointsToSet(b).element(index); + // Make sure PointsTo sets for 'a' and 'b' are unambiguous. + EXPECT_EQ(1, points_to_a.size()); + EXPECT_EQ(points_to_a.size(), points_to_b.size()); + // Check interference. + return liveness.MayInterfere(*points_to_a[0], *points_to_b[0]); + } + + // Returns true if the top-level buffers for the given instruction maybe + // liveout of the entry computation. + // Precondition: instruction is array-shaped. + bool InstructionMaybeLiveOut(const BufferLiveness& liveness, + HloInstruction* instruction) { + return liveness.MaybeLiveOut( + GetBuffer(liveness, instruction, /*index=*/{})); + } + + const Shape vec_ = ShapeUtil::MakeShape(xla::F32, {42}); +}; + +TEST_F(BufferLivenessTest, ElementwiseChain) { + // A simple chain of elementwise operations. No buffers should interfere. + // + // param --> negate -> exp -> log + // + auto builder = HloComputation::Builder(TestName()); + auto param = + builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param")); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(vec_, HloOpcode::kNegate, param)); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(vec_, HloOpcode::kExp, negate)); + auto log = builder.AddInstruction( + HloInstruction::CreateUnary(vec_, HloOpcode::kLog, exp)); + + auto module = MakeUnique(TestName()); + module->AddEntryComputation(builder.Build()); + + auto liveness = + BufferLiveness::Run(module.get(), + MakeUnique(module.get())) + .ConsumeValueOrDie(); + + // No buffers should interfere. + EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, negate)); + EXPECT_FALSE(InstructionsMayInterfere(*liveness, negate, exp)); + EXPECT_FALSE(InstructionsMayInterfere(*liveness, exp, negate)); + EXPECT_FALSE(InstructionsMayInterfere(*liveness, exp, log)); + EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, log)); + + // Buffers should interfere with itself. + EXPECT_TRUE(InstructionsMayInterfere(*liveness, exp, exp)); + + // Only log is live out. + EXPECT_FALSE(InstructionMaybeLiveOut(*liveness, param)); + EXPECT_FALSE(InstructionMaybeLiveOut(*liveness, negate)); + EXPECT_FALSE(InstructionMaybeLiveOut(*liveness, exp)); + EXPECT_TRUE(InstructionMaybeLiveOut(*liveness, log)); +} + +TEST_F(BufferLivenessTest, NonElementwiseOperand) { + // A chain of operations with one elementwise and one non-elementwise. The + // elementwise op should not interfere with its operand, while the + // non-elementwise op should interfere. 
+ // + // param --> negate -> reverse + // + auto builder = HloComputation::Builder(TestName()); + auto param = + builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param")); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(vec_, HloOpcode::kNegate, param)); + auto reverse = + builder.AddInstruction(HloInstruction::CreateReverse(vec_, negate, {0})); + + auto module = MakeUnique(TestName()); + module->AddEntryComputation(builder.Build()); + + auto liveness = + BufferLiveness::Run(module.get(), + MakeUnique(module.get())) + .ConsumeValueOrDie(); + + // No buffers should interfere. + EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, negate)); + EXPECT_TRUE(InstructionsMayInterfere(*liveness, negate, reverse)); + EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, negate)); +} + +TEST_F(BufferLivenessTest, OverlappedBuffers) { + // Verify simultaneously live buffers interfere (exp and negate). + // + // param --> negate -> add + // \---> exp -----/ + // + auto builder = HloComputation::Builder(TestName()); + auto param = + builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param")); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(vec_, HloOpcode::kNegate, param)); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(vec_, HloOpcode::kExp, param)); + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, negate, exp)); + + auto module = MakeUnique(TestName()); + module->AddEntryComputation(builder.Build()); + + auto liveness = + BufferLiveness::Run(module.get(), + MakeUnique(module.get())) + .ConsumeValueOrDie(); + + EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, negate)); + EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, exp)); + EXPECT_TRUE(InstructionsMayInterfere(*liveness, negate, exp)); + EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, add)); +} + +TEST_F(BufferLivenessTest, OverlappedBuffersSequentialOrder) { + // Identical to the test OverlappedBuffer but using a sequential ordering of + // HLO instructions. + // + // param --> negate -> add + // \---> exp -----/ + // + // Sequential order: + // param, negate, exp, add + // + // Liveness is identical to the DependencyHloOrdering except that 'param' and + // exp no longer interfere. + auto builder = HloComputation::Builder(TestName()); + auto param = + builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param")); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(vec_, HloOpcode::kNegate, param)); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(vec_, HloOpcode::kExp, param)); + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, negate, exp)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + + SequentialHloOrdering::HloModuleSequence module_sequence; + std::vector order = {param, negate, exp, add}; + module_sequence.emplace(computation, order); + auto liveness = + BufferLiveness::Run(module.get(), MakeUnique( + module.get(), module_sequence)) + .ConsumeValueOrDie(); + + EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, negate)); + EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, exp)); + EXPECT_TRUE(InstructionsMayInterfere(*liveness, negate, exp)); + EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, add)); +} + +TEST_F(BufferLivenessTest, TupleLiveOut) { + // Verify MaybeLiveOut with nested tuples. 
Result of computation looks like: + // + // Tuple({Tuple({Negate(Param)}, Exp(Negate(Param)))}) + // + // All values should be live out except Param. + auto builder = HloComputation::Builder(TestName()); + auto param = + builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param")); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(vec_, HloOpcode::kNegate, param)); + auto inner_tuple = + builder.AddInstruction(HloInstruction::CreateTuple({negate})); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(vec_, HloOpcode::kExp, negate)); + auto outer_tuple = + builder.AddInstruction(HloInstruction::CreateTuple({inner_tuple, exp})); + + auto module = MakeUnique(TestName()); + module->AddEntryComputation(builder.Build()); + + auto liveness = + BufferLiveness::Run(module.get(), + MakeUnique(module.get())) + .ConsumeValueOrDie(); + + // All buffers should be live out except the param + EXPECT_FALSE(InstructionMaybeLiveOut(*liveness, param)); + EXPECT_TRUE(InstructionMaybeLiveOut(*liveness, negate)); + EXPECT_TRUE(InstructionMaybeLiveOut(*liveness, inner_tuple)); + EXPECT_TRUE(InstructionMaybeLiveOut(*liveness, exp)); + EXPECT_TRUE(InstructionMaybeLiveOut(*liveness, outer_tuple)); +} + +// bitcast liveout. + +TEST_F(BufferLivenessTest, EmbeddedComputation) { + // Test MaybeLiveOut and MayInterfere for embedded computation. + auto module = MakeUnique(TestName()); + + auto embedded_builder = HloComputation::Builder(TestName() + "_embedded"); + auto embedded_param = embedded_builder.AddInstruction( + HloInstruction::CreateParameter(0, vec_, "embedded_param")); + auto embedded_log = embedded_builder.AddInstruction( + HloInstruction::CreateUnary(vec_, HloOpcode::kLog, embedded_param)); + + auto embedded_computation = + module->AddEmbeddedComputation(embedded_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + auto param = + builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param")); + auto call = builder.AddInstruction( + HloInstruction::CreateCall(vec_, {param}, embedded_computation)); + + module->AddEntryComputation(builder.Build()); + + auto liveness = + BufferLiveness::Run(module.get(), + MakeUnique(module.get())) + .ConsumeValueOrDie(); + + // Buffers in different computations should always interfere. + EXPECT_TRUE(InstructionsMayInterfere(*liveness, embedded_log, call)); + EXPECT_TRUE(InstructionsMayInterfere(*liveness, embedded_param, param)); + EXPECT_FALSE( + InstructionsMayInterfere(*liveness, embedded_param, embedded_log)); + + // The only buffers for which MaybeLiveOut == true are those live out + // of the entry computation. Buffers live out of embedded computations should + // return false for this method. + EXPECT_FALSE(InstructionMaybeLiveOut(*liveness, embedded_log)); + EXPECT_TRUE(InstructionMaybeLiveOut(*liveness, call)); +} + +TEST_F(BufferLivenessTest, TupleConstantLiveOut) { + // Verify non top-level elements of a nested tuple constant are properly + // marked as liveout. Computation: + // + // GetTupleElement(0, TupleConstant({{0, 1}, {3}}) + // + // Only the array buffers containing 0 and 1 are liveout of the + // computation. The buffer containing {0, 1} is copied by GetTupleElement, and + // the buffers containing {3} and 3 are dead. 
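+  // Viewed by ShapeIndex, the constant defines buffers at {}, {0}, {0, 0},
+  // {0, 1}, {1} and {1, 0}; the GetTupleElement reads only the {0} subtree,
+  // which is why only that subtree's element buffers are expected to be
+  // live out in the checks below.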
+ auto builder = HloComputation::Builder(TestName()); + auto inner_tuple0 = + LiteralUtil::MakeTuple({LiteralUtil::CreateR0(0).get(), + LiteralUtil::CreateR0(1).get()}); + auto inner_tuple1 = + LiteralUtil::MakeTuple({LiteralUtil::CreateR0(3).get()}); + auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::MakeTuple({inner_tuple0.get(), inner_tuple1.get()}))); + builder.AddInstruction(HloInstruction::CreateGetTupleElement( + inner_tuple0->shape(), tuple_constant, 0)); + + auto module = MakeUnique(TestName()); + module->AddEntryComputation(builder.Build()); + + auto liveness = + BufferLiveness::Run(module.get(), + MakeUnique(module.get())) + .ConsumeValueOrDie(); + + // Only the element buffers of the tuple constant which are pointed to by + // the GetTupleElement instruction should be liveout. + EXPECT_FALSE(liveness->MaybeLiveOut( + GetBuffer(*liveness, tuple_constant, /*index=*/{}))); + EXPECT_TRUE(liveness->MaybeLiveOut( + GetBuffer(*liveness, tuple_constant, /*index=*/{0}))); + EXPECT_TRUE(liveness->MaybeLiveOut( + GetBuffer(*liveness, tuple_constant, /*index=*/{0, 0}))); + EXPECT_TRUE(liveness->MaybeLiveOut( + GetBuffer(*liveness, tuple_constant, /*index=*/{0, 1}))); + EXPECT_FALSE(liveness->MaybeLiveOut( + GetBuffer(*liveness, tuple_constant, /*index=*/{1}))); + EXPECT_FALSE(liveness->MaybeLiveOut( + GetBuffer(*liveness, tuple_constant, /*index=*/{1, 0}))); + EXPECT_FALSE(liveness->MaybeLiveOut( + GetBuffer(*liveness, tuple_constant, /*index=*/{1, 0}))); +} + +TEST_F(BufferLivenessTest, IndependentTupleElements) { + auto builder = HloComputation::Builder(TestName()); + // Create param0 Tuple. + auto tuple_param0 = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {8}), ShapeUtil::MakeShape(S32, {4})}), + "param0")); + // Create independent computations for each tuple elememt. + + // Tuple element0 computation: + // Add(GetTupleElement(tuple_param0, 0), const0) + auto tuple_element0_shape = + ShapeUtil::GetSubshape(tuple_param0->shape(), {0}); + auto tuple_element0 = + builder.AddInstruction(HloInstruction::CreateGetTupleElement( + tuple_element0_shape, tuple_param0, 0)); + auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); + auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( + tuple_element0_shape, HloOpcode::kAdd, tuple_element0, const0)); + + // Tuple element1 computation: + // Add(GetTupleElement(tuple_param0, 1), const1) + auto tuple_element1_shape = + ShapeUtil::GetSubshape(tuple_param0->shape(), {1}); + auto tuple_element1 = + builder.AddInstruction(HloInstruction::CreateGetTupleElement( + tuple_element1_shape, tuple_param0, 1)); + auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f}))); + auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( + tuple_element1_shape, HloOpcode::kAdd, tuple_element1, const1)); + + // Create output tuple. 
+ auto tuple_root = + builder.AddInstruction(HloInstruction::CreateTuple({add0, add1})); + + auto module = MakeUnique(TestName()); + module->AddEntryComputation(builder.Build()); + + auto liveness = + BufferLiveness::Run(module.get(), + MakeUnique(module.get())) + .ConsumeValueOrDie(); + + // We compare tuple element pairs that are input/output to the computation: + // 1) (input_tuple_element, output_tuple_element) = ('tuple_element0', 'add0') + // 2) (input_tuple_element, output_tuple_element) = ('tuple_element1', 'add1') + + // Tuple output element 'add0' does not depend on input 'tuple_element1'. + // Tuple output element 'add1' does not depend on input 'tuple_element0'. + + // Both element pair does not interfere, because there is no other dependency + // on the pairs tuple input element, and so liveness can compute that all + // users of the input tuple element execute before the associated output + // tuple element. + EXPECT_FALSE( + TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {0})); + EXPECT_FALSE( + TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1})); +} + +TEST_F(BufferLivenessTest, DependentTupleElements) { + auto builder = HloComputation::Builder(TestName()); + // Create param0 Tuple. + auto tuple_param0 = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {8}), ShapeUtil::MakeShape(F32, {8})}), + "param0")); + // Create dependent computations for each tuple elememt. + + // Tuple element0 computation: + // Add(GetTupleElement(tuple_param0, 0), const0) + auto tuple_element0_shape = + ShapeUtil::GetSubshape(tuple_param0->shape(), {0}); + auto tuple_element0 = + builder.AddInstruction(HloInstruction::CreateGetTupleElement( + tuple_element0_shape, tuple_param0, 0)); + auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); + auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( + tuple_element0_shape, HloOpcode::kAdd, tuple_element0, const0)); + + // Tuple element1 computation: + // Add(GetTupleElement(tuple_param0, 0), GetTupleElement(tuple_param0, 1)) + auto tuple_element1_shape = + ShapeUtil::GetSubshape(tuple_param0->shape(), {1}); + auto tuple_element1 = + builder.AddInstruction(HloInstruction::CreateGetTupleElement( + tuple_element1_shape, tuple_param0, 1)); + auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( + tuple_element1_shape, HloOpcode::kAdd, tuple_element0, tuple_element1)); + + // Create output tuple. + auto tuple_root = + builder.AddInstruction(HloInstruction::CreateTuple({add0, add1})); + + auto module = MakeUnique(TestName()); + module->AddEntryComputation(builder.Build()); + + auto liveness = + BufferLiveness::Run(module.get(), + MakeUnique(module.get())) + .ConsumeValueOrDie(); + + // We compare tuple element pairs that are input/output to the computation: + // 1) (input_tuple_element, output_tuple_element) = ('tuple_element0', 'add0') + // 2) (input_tuple_element, output_tuple_element) = ('tuple_element1', 'add1') + + // The first tuple element pair output 'add0', has no dependency on second + // tuple element pairs input 'tuple_element1'. + + // The second tuple element pair output 'add1', has a dependency on first + // tuple element pairs input 'tuple_element0'. 
+ + // The first tuple element pair does interfere, because liveness cannot + // compute that all references to 'tuple_element0' are executed before 'add0' + // (because of the depenency of 'add1' on 'tuple_element0'). + EXPECT_TRUE( + TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {0})); + + // The second tuple element pair does not interfere, because there is no + // other dependency on 'tuple_element1', and so liveness can compute that + // all users execute before 'add1'. + EXPECT_FALSE( + TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1})); +} + +} // namespace + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/channel_tracker.cc b/tensorflow/compiler/xla/service/channel_tracker.cc new file mode 100644 index 0000000000..b3784c36ff --- /dev/null +++ b/tensorflow/compiler/xla/service/channel_tracker.cc @@ -0,0 +1,91 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/channel_tracker.h" + +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +ChannelTracker::ChannelTracker() : next_channel_(1) {} + +ChannelHandle ChannelTracker::NewChannel() { + tensorflow::mutex_lock lock(channel_mutex_); + + // Create a new channel handle with a unique value. + const ChannelHandle new_handle = AllocateHandle(); + + // Register a channel object associated with the handle. 
+ Channel channel; + channel.has_sender = false; + channel.receiver_count = 0; + opaque_to_channel_[new_handle.handle()] = channel; + + return new_handle; +} + +Status ChannelTracker::RegisterSend(const ChannelHandle& handle) { + tensorflow::mutex_lock lock(channel_mutex_); + return RegisterSendInternal(handle); +} + +Status ChannelTracker::RegisterRecv(const ChannelHandle& handle) { + tensorflow::mutex_lock lock(channel_mutex_); + return RegisterRecvInternal(handle); +} + +ChannelHandle ChannelTracker::AllocateHandle() { + int64 handle_value = next_channel_++; + ChannelHandle result; + result.set_handle(handle_value); + return result; +} + +Status ChannelTracker::RegisterSendInternal(const ChannelHandle& handle) { + if (opaque_to_channel_.count(handle.handle()) == 0) { + return NotFound("channel handle not found: %lld", handle.handle()); + } + Channel& channel = opaque_to_channel_[handle.handle()]; + if (channel.has_sender) { + return FailedPrecondition("channel handle is already used by a sender"); + } + channel.has_sender = true; + return Status::OK(); +} + +Status ChannelTracker::RegisterRecvInternal(const ChannelHandle& handle) { + if (opaque_to_channel_.count(handle.handle()) == 0) { + return NotFound("channel handle not found: %lld", handle.handle()); + } + Channel& channel = opaque_to_channel_[handle.handle()]; + // TODO(b/33942691): Allow more than 1 receivers for broadcast. + if (channel.receiver_count >= 1) { + return FailedPrecondition("channel handle is already used by a receiver"); + } + channel.receiver_count += 1; + return Status::OK(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/channel_tracker.h b/tensorflow/compiler/xla/service/channel_tracker.h new file mode 100644 index 0000000000..c7763f2ca3 --- /dev/null +++ b/tensorflow/compiler/xla/service/channel_tracker.h @@ -0,0 +1,94 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CHANNEL_TRACKER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CHANNEL_TRACKER_H_ + +#include + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/session.pb.h" +#include "tensorflow/compiler/xla/service/user_computation.h" +#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// Tracks channels between computations in the XLA service. Channels +// are associated with a unique handle and can be resolved from the handle for +// later use. 
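+//
+// A minimal usage sketch (hypothetical caller code; only the methods declared
+// below are real):
+//
+//   ChannelTracker tracker;
+//   ChannelHandle handle = tracker.NewChannel();
+//   TF_RETURN_IF_ERROR(tracker.RegisterSend(handle));  // at most one sender
+//   TF_RETURN_IF_ERROR(tracker.RegisterRecv(handle));  // at most one receiver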
+// +// TODO(b/34027823): Destruct channels when all the associated computations that +// communicate via each channel are destructed. +class ChannelTracker { + public: + ChannelTracker(); + + // A struct that keeps the current status of each channel. has_sender and + // receiver_count fields are initialized with false and 0 respectively when + // the struct is created and are updated by RegisterSend() and RegisterRecev() + // as Send or Recv instructions using the channel are requested. + struct Channel { + bool has_sender; + int64 receiver_count; + }; + + // Creates a new Channel object and returns the corresponding + // ChannelHandle for it. + ChannelHandle NewChannel(); + + // Informs that the given channel handle is used for a Send operation. + // Returns an error status if the handle is already used by another Send. + Status RegisterSend(const ChannelHandle& handle); + + // Informs that the given channel handle is used for a Recv operation. + // Returns an error status if the handle is already used by another Recv. + Status RegisterRecv(const ChannelHandle& handle); + + private: + // Bumps the next_channel_ number and returns the allocated number + // wrapped in a ChannelHandle. + ChannelHandle AllocateHandle() EXCLUSIVE_LOCKS_REQUIRED(channel_mutex_); + + Status RegisterSendInternal(const ChannelHandle& handle) + EXCLUSIVE_LOCKS_REQUIRED(channel_mutex_); + + Status RegisterRecvInternal(const ChannelHandle& handle) + EXCLUSIVE_LOCKS_REQUIRED(channel_mutex_); + + // Guards the channel mapping. + tensorflow::mutex channel_mutex_; + + // The next sequence number to assign to a channel. + int64 next_channel_ GUARDED_BY(channel_mutex_); + + // Mapping from ChannelHandle value to the corresponding registered + // Channel object. + std::map opaque_to_channel_ GUARDED_BY(channel_mutex_); + + TF_DISALLOW_COPY_AND_ASSIGN(ChannelTracker); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CHANNEL_TRACKER_H_ diff --git a/tensorflow/compiler/xla/service/compilation_cache.cc b/tensorflow/compiler/xla/service/compilation_cache.cc new file mode 100644 index 0000000000..b16907da9e --- /dev/null +++ b/tensorflow/compiler/xla/service/compilation_cache.cc @@ -0,0 +1,78 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/compilation_cache.h" + +#include + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +std::shared_ptr CompilationCache::Insert( + std::unique_ptr executable, + const HloModuleConfig& module_config) { + tensorflow::mutex_lock lock(mutex_); + + CacheKey key = + BuildKey(executable->entry_computation_handle(), module_config); + VLOG(2) << "inserting cache key: " << key; + if (cache_.count(key) == 0) { + cache_.emplace(key, std::move(executable)); + } else { + // Executable already exists in the cache. This can happen if two Execute + // calls for a new computation are received simultaneously by the + // service. In this case, we discard the Executable given as a parameter and + // return what is in the cache. This is necessary because the service relies + // on the cache to keep ownership of the Executable. We only want to store + // one Executable for a given computation version and we can't discard the + // executable which is in the cache because it may be in use. + executable.reset(); + } + return cache_.at(key); +} + +std::shared_ptr CompilationCache::LookUp( + const VersionedComputationHandle& versioned_handle, + const HloModuleConfig& module_config) const { + tensorflow::mutex_lock lock(mutex_); + + CacheKey key = BuildKey(versioned_handle, module_config); + VLOG(2) << "looking up cache key: " << key; + if (cache_.count(key) == 0) { + VLOG(2) << "cache key not found: " << key; + return nullptr; + } else { + std::shared_ptr result = cache_.at(key); + VLOG(2) << "hit executable with module config: " + << result->module_config().compilation_cache_key(); + return result; + } +} + +CompilationCache::CacheKey CompilationCache::BuildKey( + const VersionedComputationHandle& versioned_handle, + const HloModuleConfig& module_config) const { + // The computation shape is represented entirely by its ProgramShape member, + // so just serialize the proto as part of the key. + return tensorflow::strings::StrCat(versioned_handle.handle.handle(), "::", + versioned_handle.version, "::", + module_config.compilation_cache_key()); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/compilation_cache.h b/tensorflow/compiler/xla/service/compilation_cache.h new file mode 100644 index 0000000000..09989726ae --- /dev/null +++ b/tensorflow/compiler/xla/service/compilation_cache.h @@ -0,0 +1,78 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/service/executable.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" + +namespace xla { + +// A cache which stores Executables indexed by computation handle and version. +class CompilationCache { + public: + CompilationCache() {} + + // Insert the given Executable into the cache. Return a bare Executable + // pointer for the caller to use. Note: the returned pointer will *not* be the + // same as the given unique pointer if the computation already exists in the + // cache. See comments in the .cc implementation for details of this case. + // + // module_config is provided by the caller, instead of being taken from the + // executable, so that we can insert keys into the compilation cache that are + // devoid of layout (where XLA gets to choose what layout to compile). + // + // A shared_ptr is returned so the caller can keep the Executable from being + // destructed in the event that the Executable is evicted from the + // computation cache (and the cache's shared_ptr to the Executable is + // destructed). + std::shared_ptr Insert(std::unique_ptr executable, + const HloModuleConfig& module_config); + + // Lookup the Executable for the specified versioned computation in the cache. + // Return a shared_ptr to the Executable if it exists in the cache. Return + // nullptr otherwise. + std::shared_ptr LookUp( + const VersionedComputationHandle& versioned_handle, + const HloModuleConfig& module_config) const; + + protected: + mutable tensorflow::mutex mutex_; + + // Map from versioned handle with program layout to Executable built + // for that computation version and program layout. + using CacheKey = string; + + CacheKey BuildKey(const VersionedComputationHandle& versioned_handle, + const HloModuleConfig& module_config) const; + std::map> cache_ GUARDED_BY(mutex_); + + private: + TF_DISALLOW_COPY_AND_ASSIGN(CompilationCache); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_ diff --git a/tensorflow/compiler/xla/service/compiler.cc b/tensorflow/compiler/xla/service/compiler.cc new file mode 100644 index 0000000000..f71b2b6b9c --- /dev/null +++ b/tensorflow/compiler/xla/service/compiler.cc @@ -0,0 +1,96 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/compiler.h" + +#include +#include + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" + +namespace se = ::perftools::gputools; + +namespace xla { + +/* static */ tensorflow::mutex* Compiler::platform_compiler_mutex_; + +/* static */ void Compiler::LazyInitMutex() { + static std::once_flag mutex_init_flag; + std::call_once(mutex_init_flag, []() { + Compiler::platform_compiler_mutex_ = new tensorflow::mutex; + }); +} + +/* static */ std::map* +Compiler::GetPlatformCompilerFactories() { + static auto* r = + new std::map; + return r; +} + +/* static */ +std::map>* +Compiler::GetPlatformCompilers() { + static auto* r = new std::map>; + return r; +} + +/* static */ void Compiler::RegisterCompilerFactory( + se::Platform::Id platform_id, + std::function()> compiler_factory) { + LazyInitMutex(); + tensorflow::mutex_lock lock(*platform_compiler_mutex_); + auto* factories = GetPlatformCompilerFactories(); + CHECK(factories->find(platform_id) == factories->end()); + (*factories)[platform_id] = std::move(compiler_factory); +} + +/* static */ StatusOr Compiler::GetForPlatform( + const se::Platform* platform) { + LazyInitMutex(); + tensorflow::mutex_lock lock(*platform_compiler_mutex_); + + auto* compilers = GetPlatformCompilers(); + // See if we already instantiated a compiler for this platform. + { + auto it = compilers->find(platform->id()); + if (it != compilers->end()) { + return it->second.get(); + } + + // If not, we just fall through to try to create one with a registered + // factory. + } + + auto* factories = GetPlatformCompilerFactories(); + auto it = factories->find(platform->id()); + if (it == factories->end()) { + return NotFound( + "could not find registered compiler for platform %s -- check " + "target linkage", + platform->Name().c_str()); + } + + // And then we invoke the factory, placing the result into the mapping. + compilers->insert(std::make_pair(platform->id(), it->second())); + return compilers->at(platform->id()).get(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h new file mode 100644 index 0000000000..632081a747 --- /dev/null +++ b/tensorflow/compiler/xla/service/compiler.h @@ -0,0 +1,172 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// The compiler API is used by the XLA service to generate executables that +// run on a given platform. This is a registry and abstract interface, for +// pluggability by the various platforms. 
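+//
+// A minimal registration sketch (MyPlatformCompiler and my_platform_id are
+// hypothetical; RegisterCompilerFactory and GetForPlatform are declared
+// below):
+//
+//   // Usually done from a module initializer in the backend library.
+//   Compiler::RegisterCompilerFactory(
+//       my_platform_id, [] { return MakeUnique<MyPlatformCompiler>(); });
+//
+//   // The service later resolves the per-platform singleton on demand.
+//   StatusOr<Compiler*> compiler = Compiler::GetForPlatform(platform);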
+ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPILER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPILER_H_ + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/service/executable.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/thread_annotations.h" + +namespace xla { + +// The following types are used for ahead of time compilation. + +// Contains the object file data created as a result of ahead-of-time +// compuation. +using ObjectFileData = std::vector; + +// Contains the buffer sizes information needed to allocate buffers to execute +// an ahead-of-time computation. Entries which contain -1 designate a parameter +// which should be skipped over during allocation. +using BufferSizes = std::vector; + +// Abstract superclass describing the result of an ahead-of-time compilation. +class AotCompilationResult { + public: + AotCompilationResult(const AotCompilationResult&) = delete; + AotCompilationResult& operator=(AotCompilationResult const&) = delete; + + virtual ~AotCompilationResult() = default; + + protected: + AotCompilationResult() = default; +}; + +// Abstract superclass describing options to an ahead-of-time compilation. +class AotCompilationOptions { + public: + AotCompilationOptions(const AotCompilationOptions&) = delete; + AotCompilationOptions& operator=(AotCompilationOptions const&) = delete; + + virtual ~AotCompilationOptions() = default; + + // Returns the ID of the platform to which these options apply. + virtual perftools::gputools::Platform::Id PlatformId() const = 0; + + protected: + AotCompilationOptions() = default; +}; + +// Abstract compiler interface that is subclassed for compilation on a +// particular platform. +// +// The compiler ties together high level optimization (HLO) and low level +// optimization (LLO) / codegen (CG) to generate efficient executables for the +// target platform. +// +// The platform-based compiler singletons are registered via module initializers +// in their corresponding XLA compiler libraries, and are registered via the +// RegisterCompilerFactory API below. +// +// Thread-safety: subclasses of Compiler must be thread-safe, as multiple +// XLA clients may be requesting compilation concurrently for a given +// platform. +class Compiler { + public: + // Callback signature used to dump the HLO graph during compilation. + // Different compiler backends will call this as they please, providing + // a view of the HLO at different points in compilation -- context for the + // dump is indicated by the label string. + using HloDumper = + std::function; + + virtual ~Compiler() {} + + // Returns the ID of the platform that this compiler targets. + virtual perftools::gputools::Platform::Id PlatformId() const = 0; + + // Compiles the HLO module for execution on a device given by the executor, + // and returns an executable object or an error status. Takes ownership of the + // HLO module and is free to transform it. + // + // The compiler may optionally specialize to the individual device + // (not just type of device) indicated by the executor. 
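+  //
+  // A rough call sketch (hypothetical caller; `module`, `module_config`,
+  // `dump_hlo` and `stream_exec` are assumed to exist in the calling code):
+  //
+  //   TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
+  //                       compiler->Compile(std::move(module),
+  //                                         std::move(module_config),
+  //                                         dump_hlo, stream_exec));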
+ // + // TODO(leary) will need to update this API when a single computation can run + // across multiple devices simultaneously. + virtual StatusOr> Compile( + std::unique_ptr module, + std::unique_ptr module_config, HloDumper dump_hlo, + perftools::gputools::StreamExecutor* executor) = 0; + + // Compiles a set of HLO modules that can run in parallel, potentially + // communicating data between the modules, and returns a corresponding + // sequence of executable objects. + virtual StatusOr>> Compile( + std::vector> hlo_module, + std::vector> module_config, + HloDumper dump_hlo, + std::vector stream_exec) = 0; + + // Compiles the HLO module for ahead-of-time execution. This is intended for + // use in static compilation. + virtual StatusOr> CompileAheadOfTime( + std::unique_ptr module, + std::unique_ptr module_config, HloDumper dump_hlo, + const AotCompilationOptions& options) = 0; + + ///// + // The Compiler class also serves as a point to register compiler objects + // for the various platforms. + + using CompilerFactory = std::function()>; + + // Registers the compiler singleton for the platform. This is assumed to + // be a singleton, so no ownership is transferred. + // + // Precondition: a platform kind must not be registered more than once. + static void RegisterCompilerFactory( + perftools::gputools::Platform::Id platform_id, + CompilerFactory compiler_factory); + + // Returns the compiler singleton pointer if it is available for the given + // platform, or an error status if it is not. + static StatusOr GetForPlatform( + const perftools::gputools::Platform* platform); + + private: + // Mutex that guards the platform-compiler map. + static tensorflow::mutex* platform_compiler_mutex_; + static void LazyInitMutex(); + + // Map from platform kind to compiler factory. + static std::map* + GetPlatformCompilerFactories(); + + // Map from platform kind to compiler instance, if we made one already (based + // on the factories above). + static std::map>* + GetPlatformCompilers(); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPILER_H_ diff --git a/tensorflow/compiler/xla/service/computation_layout.cc b/tensorflow/compiler/xla/service/computation_layout.cc new file mode 100644 index 0000000000..d2d4f14fce --- /dev/null +++ b/tensorflow/compiler/xla/service/computation_layout.cc @@ -0,0 +1,57 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/computation_layout.h" + +#include + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace xla { + +ComputationLayout::ComputationLayout(const ProgramShape& program_shape) + : result_layout_(program_shape.result()) { + for (auto& shape : program_shape.parameters()) { + parameter_layouts_.emplace_back(shape); + } + SetToDefaultLayout(); +} + +void ComputationLayout::SetToDefaultLayout() { + for (auto& parameter_layout : parameter_layouts_) { + parameter_layout.SetToDefaultLayout(); + } + result_layout_.SetToDefaultLayout(); +} + +bool ComputationLayout::LayoutIsSet() const { + return std::all_of(parameter_layouts_.begin(), parameter_layouts_.end(), + [](const ShapeLayout& s) { return s.LayoutIsSet(); }) && + result_layout_.LayoutIsSet(); +} + +string ComputationLayout::ToString() const { + std::vector params; + for (auto& param_layout : parameter_layouts_) { + params.push_back(param_layout.ToString()); + } + return tensorflow::strings::StrCat("(", + tensorflow::str_util::Join(params, ", "), + ") => ", result_layout_.ToString()); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/computation_layout.h b/tensorflow/compiler/xla/service/computation_layout.h new file mode 100644 index 0000000000..80e102411c --- /dev/null +++ b/tensorflow/compiler/xla/service/computation_layout.h @@ -0,0 +1,83 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_LAYOUT_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_LAYOUT_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/shape_layout.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// Class which contains the layouts of the parameters and results of a +// computation. The layouts are stored as ShapeLayouts with immutable shapes and +// mutable layouts. +class ComputationLayout { + public: + // Constructs a ComputationLayout from a ProgramShape. The layouts of the + // parameters and results are set to the default layout. Layouts in the + // ProgramShape are ignored. + explicit ComputationLayout(const ProgramShape& program_shape); + + // Returns the layout of a particular parameter. + const ShapeLayout& parameter_layout(int64 param_no) const { + return parameter_layouts_[param_no]; + } + ShapeLayout* mutable_parameter_layout(int64 param_no) { + return ¶meter_layouts_[param_no]; + } + + // Returns the number of parameters in the computation. + int parameter_count() const { return parameter_layouts_.size(); } + + // Returns the ShapeLayouts of the parameters of the computation. 
+ const std::vector& parameter_layouts() const { + return parameter_layouts_; + } + + // Returns the ShapeLayout of a result of the computation. + const ShapeLayout& result_layout() const { return result_layout_; } + ShapeLayout* mutable_result_layout() { return &result_layout_; } + + // Returns the shape of the particular parameter or result of the computation + // with layout. + const Shape& parameter_shape(int64 param_no) const { + return parameter_layouts_[param_no].shape(); + } + const Shape& result_shape() const { return result_layout_.shape(); } + + // Sets layouts of all parameters and the result to the default layout. + void SetToDefaultLayout(); + + // Returns true if all layouts (parameters and result) have been set. + bool LayoutIsSet() const; + + // Returns a string representation of this object. + string ToString() const; + + private: + std::vector parameter_layouts_; + ShapeLayout result_layout_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_LAYOUT_H_ diff --git a/tensorflow/compiler/xla/service/computation_tracker.cc b/tensorflow/compiler/xla/service/computation_tracker.cc new file mode 100644 index 0000000000..281277bed5 --- /dev/null +++ b/tensorflow/compiler/xla/service/computation_tracker.cc @@ -0,0 +1,204 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/computation_tracker.h" + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +ComputationTracker::ComputationTracker() : next_computation_(1) {} + +ComputationHandle ComputationTracker::NewComputation( + const string& computation_name) { + tensorflow::mutex_lock lock(computation_mutex_); + ComputationHandle computation_handle; + int64 handle_value = next_computation_++; + computation_handle.set_handle(handle_value); + opaque_to_computation_[handle_value] = + MakeUnique(computation_name, computation_handle); + return computation_handle; +} + +StatusOr ComputationTracker::LoadSessionModule( + const SessionModule& session_module) { + tensorflow::mutex_lock lock(computation_mutex_); + + // For each embedded computation, create a new computation based on its + // serialized data, and place the mapping from the old computation handle to + // the new computation handle. 
+ std::map old_to_new; + for (const SessionComputation& computation : + session_module.embedded_computations()) { + const int64 old_handle = computation.computation_handle().handle(); + TF_ASSIGN_OR_RETURN(old_to_new[old_handle], + LoadSessionComputation(computation, &old_to_new)); + } + + // Finally, place the entry computation in the tracker with all of the + // remappings populated from the above. + const int64 old_handle = session_module.entry().computation_handle().handle(); + TF_ASSIGN_OR_RETURN( + old_to_new[old_handle], + LoadSessionComputation(session_module.entry(), &old_to_new)); + return old_to_new[old_handle]; +} + +StatusOr> +ComputationTracker::SnapshotComputation(const ComputationHandle& computation) { + TF_ASSIGN_OR_RETURN(UserComputation * user_computation, Resolve(computation)); + const VersionedComputationHandle entry_versioned_handle = + user_computation->GetVersionedHandle(); + std::set visited; + std::list post_order; + { + tensorflow::mutex_lock lock(computation_mutex_); + ComputeComputationPostOrder(entry_versioned_handle, &visited, &post_order); + } + auto session_module = MakeUnique(); + *session_module->mutable_entry() = + Resolve(entry_versioned_handle.handle) + .ValueOrDie() + ->CloneSessionComputation(entry_versioned_handle.version); + for (auto it = ++post_order.rbegin(); it != post_order.rend(); ++it) { + *session_module->add_embedded_computations() = + Resolve(it->handle).ValueOrDie()->CloneSessionComputation(it->version); + } + return std::move(session_module); +} + +StatusOr ComputationTracker::Resolve( + const ComputationHandle& computation) const { + tensorflow::mutex_lock lock(computation_mutex_); + return ResolveInternal(computation); +} + +ComputationHandle ComputationTracker::AllocateHandle() { + int64 handle_value = next_computation_++; + ComputationHandle result; + result.set_handle(handle_value); + return result; +} + +StatusOr ComputationTracker::LoadSessionComputation( + const SessionComputation& session_computation, + std::map* old_to_new) { + TF_RET_CHECK(old_to_new != nullptr); + const ComputationHandle new_handle = AllocateHandle(); + (*old_to_new)[session_computation.computation_handle().handle()] = new_handle; + TF_ASSIGN_OR_RETURN(opaque_to_computation_[new_handle.handle()], + UserComputation::MakeWithRemapping( + session_computation, new_handle, *old_to_new)); + return new_handle; +} + +StatusOr ComputationTracker::ResolveInternal( + const ComputationHandle& computation) const { + auto it = opaque_to_computation_.find(computation.handle()); + if (it == opaque_to_computation_.end()) { + return NotFound("computation handle not found: %lld", computation.handle()); + } + UserComputation* user_computation = it->second.get(); + return user_computation; +} + +void ComputationTracker::ComputeComputationPostOrder( + const VersionedComputationHandle& versioned_handle, + std::set* visited, + std::list* post_order) const { + if (visited->count(versioned_handle) > 0) { + DCHECK_EQ(1, visited->count(versioned_handle)); + return; + } + + UserComputation* computation = + ResolveInternal(versioned_handle.handle).ValueOrDie(); + std::vector embedded_handles = + computation->GetEmbeddedComputations(versioned_handle.version); + + for (const auto& embedded_handle : embedded_handles) { + ComputeComputationPostOrder(embedded_handle, visited, post_order); + } + + visited->insert(versioned_handle); + post_order->push_back(versioned_handle); + return; +} + +StatusOr> ComputationTracker::BuildHloModule( + const VersionedComputationHandle& entry_handle, + bool 
include_unused_parameters) const { + tensorflow::mutex_lock lock(computation_mutex_); + + TF_ASSIGN_OR_RETURN(UserComputation * entry_computation, + ResolveInternal(entry_handle.handle)); + + // Build a topological sort of the entry and any embedded computations as a + // list. The root of the computation will be the last element in the list. + std::set visited; + std::list post_order; + ComputeComputationPostOrder(entry_handle, &visited, &post_order); + + // Map from ComputationHandle value and computation version to HloComputation. + std::map hlo_computations; + + // The resolver lambda resolves VersionedHandles to embedded + // HloComputation*. This is required by UserComputation::BuildHloComputation + // when lowering calling operations (map, reduce etc). + auto resolver = [&hlo_computations]( + const VersionedComputationHandle& versioned_handle) -> HloComputation* { + CHECK_GT(hlo_computations.count(versioned_handle), 0); + return hlo_computations.at(versioned_handle); + }; + + string module_name = + tensorflow::strings::StrCat(entry_computation->name(), "_module"); + auto module = MakeUnique(module_name, entry_handle); + for (auto versioned_handle : post_order) { + UserComputation* computation = + ResolveInternal(versioned_handle.handle).ValueOrDie(); + + TF_ASSIGN_OR_RETURN( + std::unique_ptr hlo_computation, + computation->BuildHloComputation(versioned_handle.version, resolver, + include_unused_parameters)); + + // Add the newly created computation to VersionedHandle-to-HloComputation + // map. + DCHECK_EQ(0, hlo_computations.count(versioned_handle)); + hlo_computations[versioned_handle] = hlo_computation.get(); + + if (computation == entry_computation) { + module->AddEntryComputation(std::move(hlo_computation)); + } else { + module->AddEmbeddedComputation(std::move(hlo_computation)); + } + } + + return std::move(module); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/computation_tracker.h b/tensorflow/compiler/xla/service/computation_tracker.h new file mode 100644 index 0000000000..7d0660d7f6 --- /dev/null +++ b/tensorflow/compiler/xla/service/computation_tracker.h @@ -0,0 +1,139 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_TRACKER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_TRACKER_H_ + +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/session.pb.h" +#include "tensorflow/compiler/xla/service/user_computation.h" +#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// Tracks computations for the XLA service; computations can be registered +// with a UserComputation instance and can be resolved from a handle for later +// use. +// +// This class is also capable of serializing/deserializing computations that it +// tracks (and to serialize properly you need to serialize all referred-to +// computations as well). +class ComputationTracker { + public: + ComputationTracker(); + + // Creates a new UserComputation object and returns the corresponding + // ComputationHandle for it. + // + // Precondition: user_computation is not already present in the map. + ComputationHandle NewComputation(const string& computation_name); + + // Restores session data for a computation that has been serialized, and + // allocates a new computation handle for it. + StatusOr LoadSessionModule( + const SessionModule& session_module); + + // Snapshots a computation (referenced by the provided handle) at its latest + // version, returning a module where it is the entry, and any referred-to + // computations are entrained as "embedded" (non-entry) computations. + StatusOr> SnapshotComputation( + const ComputationHandle& computation); + + // Resolves a ComputationHandle to a UserComputation that is present in the + // map. + StatusOr Resolve( + const ComputationHandle& computation) const; + + // Builds an HLO module using the specified computation as the entry. The + // module will include the entry computation as well as all computations which + // are called directly or indirectly from the entry computation via operations + // like "map". If include_unused_parameters is true, then all parameters are + // lowered to HLO instructions even if they are not used. This ensures the + // entry HloComputation has the same program shape (ProgramShape) as the entry + // UserComputation. + StatusOr> BuildHloModule( + const VersionedComputationHandle& entry_handle, + bool include_unused_parameters = true) const; + + private: + // Bumps the next_computation_ number and returns the allocated number wrapped + // in a ComputationHandle. + ComputationHandle AllocateHandle() + EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_); + + // Loads a session computation into a UserComputation, registers it, and + // returns the computation handle of the registered computation. If old_to_new + // is provided, it is used for remapping references to computations present in + // session_computation. + // + // old_to_new will be updated with the mapping from session_computation's old + // handle to the returned handle value, and may not be null. 
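+  //
+  // Illustrative sketch only (hypothetical handle values): if a serialized
+  // module's embedded computations carried old handles {3, 7} and its entry
+  // carried old handle 9, and this tracker's next handle were 12, then after
+  // loading, old_to_new would hold {3 -> 12, 7 -> 13, 9 -> 14} and handle 14
+  // would be the one returned for the entry by LoadSessionModule.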
+  StatusOr<ComputationHandle> LoadSessionComputation(
+      const SessionComputation& session_computation,
+      std::map<int64, ComputationHandle>* old_to_new)
+      EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_);
+
+  // Internal implementation of Resolve method which requires, but does not
+  // acquire the mutex.
+  StatusOr<UserComputation*> ResolveInternal(
+      const ComputationHandle& computation) const
+      EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_);
+
+  // Builds a post order sort of a computation ("entry") and all of its embedded
+  // computations including all transitively embedded computations. An embedded
+  // computation (the callee) will always appear in the sort before the
+  // computation which calls the embedded computation (the caller). Necessarily,
+  // the entry computation is the last element in the sort. visited and
+  // post_order should be empty when calling. post_order contains the post order
+  // sort when the function returns.
+  void ComputeComputationPostOrder(
+      const VersionedComputationHandle& versioned_handle,
+      std::set<VersionedComputationHandle>* visited,
+      std::list<VersionedComputationHandle>* post_order) const
+      EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_);
+
+  // Guards the computation mapping. Marked mutable so that the Resolve method
+  // can remain const; Resolve doesn't really modify the tracker in any way, but
+  // it has to lock the mutex for safety.
+  mutable tensorflow::mutex computation_mutex_;
+
+  // The next sequence number to assign to a computation, guarded by the same
+  // mutex as the mapping as they'll be mutated at the same time.
+  int64 next_computation_ GUARDED_BY(computation_mutex_);
+
+  // Mapping from ComputationHandle value to the corresponding registered
+  // UserComputation object.
+  std::map<int64, std::unique_ptr<UserComputation>> opaque_to_computation_
+      GUARDED_BY(computation_mutex_);
+
+  TF_DISALLOW_COPY_AND_ASSIGN(ComputationTracker);
+};
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_TRACKER_H_
diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc
new file mode 100644
index 0000000000..dbf5085c1e
--- /dev/null
+++ b/tensorflow/compiler/xla/service/copy_insertion.cc
@@ -0,0 +1,439 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/copy_insertion.h"
+
+#include
+#include
+#include
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/logical_buffer.h"
+#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+
+namespace {
+
+// InstructionCopier encapsulates indices at which to copy 'instruction'.
+// All 'instruction' users in 'copy_users' are updated to use the copy.
+//
+// Instruction copies are generated in two phases:
+// 1) Recording buffer indices at which 'instruction' requires copies (i.e.
+//    setting 'indices_to_copy_[index]'=true).
+// 2) Inserting kCopy instructions based on indices recorded in phase 1).
+//   *) Array instructions are copied by inserting a single kCopy instruction.
+//   *) Tuple-shaped instructions are copied by recursively expanding tuples
+//      (and tuple-shaped elements), and inserting kCopy instructions for any
+//      tuple elements which require a copy. As the recursion unwinds, new tuple
+//      instructions are added to gather the copied (and uncopied) references
+//      into the output tuple (i.e. the copy of the tuple-shaped instruction).
+//
+// Example two-element tuple with one element that needs a copy:
+//
+//            Tuple   // instruction
+//           /     \
+//      GTE(0)     GTE(1)
+//        |           |
+//      Copy          |
+//           \       /
+//            Tuple   // copied-instruction
+//
+class InstructionCopier {
+ public:
+  InstructionCopier(const bool init_value, HloInstruction* instruction,
+                    const std::vector<HloInstruction*>& copy_users);
+
+  // Returns true if all recorded indices are false (returns false otherwise).
+  bool HasAllIndicesFalse() const;
+
+  // Records instruction buffer indices which point-to a Parameter or Constant.
+  tensorflow::Status RecordIndicesWhichPointToParamOrConstant(
+      const TuplePointsToAnalysis& points_to_analysis);
+
+  // Records instruction buffer indices to copy which are necessary to ensure:
+  // *) PointsToSet of 'instruction_' is unambiguous and distinct.
+  // *) No liveness interference between 'instruction_' and 'other_instruction'.
+  tensorflow::Status RecordIndicesToCopyForColocatingBuffers(
+      BufferLiveness* liveness, HloInstruction* other_instruction);
+
+  // Inserts copies of 'instruction' buffers at indices in 'indices_to_copy',
+  // and replaces all uses for instructions in 'copy_users_' with the copy.
+  // Returns the instruction which is a copy of 'instruction'.
+  HloInstruction* Copy();
+
+  HloInstruction* instruction() { return instruction_; }
+
+  const std::vector<HloInstruction*>& copy_users() const { return copy_users_; }
+
+ private:
+  // Records instruction buffer indices which have ambiguous or non-distinct
+  // points-to sets.
+  tensorflow::Status RecordAmbiguousOrNonDistinctIndices(
+      const TuplePointsToAnalysis& points_to_analysis);
+
+  // Records instruction buffer indices which have interfering live ranges
+  // with 'other_instruction' buffers at the same index.
+ tensorflow::Status RecordIndicesWhichInterfereWithOtherInstruction( + BufferLiveness* liveness, HloInstruction* other_instruction); + + // Recursively inserts copies of 'instruction' tuple elements at indices + // specified in 'indices_to_copy', and returns the copy of 'instruction'. + HloInstruction* CopyTuple(HloInstruction* instruction, ShapeIndex* index); + + void RecordIndex(const ShapeIndex& index) { + *indices_to_copy_.mutable_element(index) = true; + } + + HloInstruction* instruction_; + std::vector copy_users_; + ShapeTree indices_to_copy_; +}; + +InstructionCopier::InstructionCopier( + const bool init_value, HloInstruction* instruction, + const std::vector& copy_users) + : instruction_(instruction), + copy_users_(copy_users), + indices_to_copy_(instruction->shape(), init_value) {} + +bool InstructionCopier::HasAllIndicesFalse() const { + bool all_indices_false = true; + TF_CHECK_OK(indices_to_copy_.ForEachElement([&all_indices_false]( + const ShapeIndex& /*index*/, bool /*is_leaf*/, const bool& data) { + if (data) all_indices_false = false; + return tensorflow::Status::OK(); + })); + return all_indices_false; +} + +tensorflow::Status InstructionCopier::RecordIndicesWhichPointToParamOrConstant( + const TuplePointsToAnalysis& points_to_analysis) { + const PointsToSet& points_to = + points_to_analysis.GetPointsToSet(instruction_); + // Shallow copy the instruction if the points-to set of the top-level + // buffer is ambiguous. This is necessary because the backends must know + // statically what the top-level buffer of the result is. + if (points_to.element(/*index=*/{}).size() > 1) { + RecordIndex({}); + } + + // Multiple buffers within a parameter/constant may be live out, so collect + // a set of indices at which to copy first. + TF_RETURN_IF_ERROR(points_to.ForEachElement([this]( + const ShapeIndex& index, bool /*is_leaf*/, + const std::vector& buffers) { + for (auto buffer : buffers) { + // pointee is the HloInstruction producing the buffer which may be + // liveout. + HloInstruction* pointee = buffer->instruction(); + if (pointee->opcode() == HloOpcode::kParameter || + pointee->opcode() == HloOpcode::kConstant) { + VLOG(2) << "Parameter or constant buffer " << buffer->ToString() + << " index: " << tensorflow::str_util::Join(index, ",") + << " may be live out of computation: " << pointee->ToString(); + RecordIndex(index); + } + } + return tensorflow::Status::OK(); + })); + return tensorflow::Status::OK(); +} + +tensorflow::Status InstructionCopier::RecordIndicesToCopyForColocatingBuffers( + BufferLiveness* liveness, HloInstruction* other_instruction) { + TF_RETURN_IF_ERROR( + RecordAmbiguousOrNonDistinctIndices(liveness->points_to_analysis())); + TF_RETURN_IF_ERROR(RecordIndicesWhichInterfereWithOtherInstruction( + liveness, other_instruction)); + return tensorflow::Status::OK(); +} + +tensorflow::Status InstructionCopier::RecordAmbiguousOrNonDistinctIndices( + const TuplePointsToAnalysis& points_to_analysis) { + const PointsToSet& points_to = + points_to_analysis.GetPointsToSet(instruction_); + // Mapping from LogicalBuffer to index (used to detect non-distinct indices). + // TODO(b/32116879) User ShapeIndex here when it is available. + std::unordered_map> + buffer_to_source_indices; + TF_RETURN_IF_ERROR(points_to.ForEachElement([this, &buffer_to_source_indices]( + const ShapeIndex& index, bool /*is_leaf*/, + const std::vector& buffers) { + if (buffers.size() > 1) { + // Record ambiguous points-to set at 'index'. 
+      if (!indices_to_copy_.element(index)) {
+        VLOG(2) << "Adding copy of buffer for instruction: "
+                << instruction_->name()
+                << " at index: " << tensorflow::str_util::Join(index, ",")
+                << " with ambiguous points-to set.";
+        RecordIndex(index);
+      }
+    }
+    // For each 'buffer': record a mapping from 'buffer' to 'index'.
+    for (auto& buffer : buffers) {
+      auto it = buffer_to_source_indices.find(buffer);
+      if (it == buffer_to_source_indices.end()) {
+        buffer_to_source_indices.insert({buffer, std::vector<ShapeIndex>()});
+      }
+      buffer_to_source_indices[buffer].push_back(index);
+    }
+    return tensorflow::Status::OK();
+  }));
+
+  // Record all non-distinct indices detected in 'buffer_to_source_indices'.
+  for (auto& buff_to_src : buffer_to_source_indices) {
+    if (buff_to_src.second.size() == 1) {
+      continue;
+    }
+    for (auto& src_index : buff_to_src.second) {
+      // Record non-distinct points-to set at 'src_index'.
+      if (!indices_to_copy_.element(src_index)) {
+        VLOG(2) << "Adding copy of buffer for instruction: "
+                << instruction_->name()
+                << " at index: " << tensorflow::str_util::Join(src_index, ",")
+                << " because of non-distinct points-to set.";
+        RecordIndex(src_index);
+      }
+    }
+  }
+  return tensorflow::Status::OK();
+}
+
+tensorflow::Status
+InstructionCopier::RecordIndicesWhichInterfereWithOtherInstruction(
+    BufferLiveness* liveness, HloInstruction* other_instruction) {
+  // Record all buffer indices for 'instruction_', which interfere with
+  // 'other_instruction' at the same index.
+  TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshape(
+      instruction_->shape(),
+      [this, &liveness, &other_instruction](const Shape& /*subshape*/,
+                                            const ShapeIndex& index) {
+        if (indices_to_copy_.element(index)) {
+          // Return if previous pass already set index.
+          return tensorflow::Status::OK();
+        }
+        auto& points_to_analysis = liveness->points_to_analysis();
+        // Lookup buffers for 'instruction_' and 'other_instruction'.
+        const std::vector<const LogicalBuffer*> instruction_buffers =
+            points_to_analysis.GetPointsToSet(instruction_).element(index);
+        // If 'instruction_' has an ambiguous points-to set at 'index', it would
+        // have been recorded in a previous pass (and we would have returned
+        // early at the entry to this function). As a result, here we know that
+        // 'instruction_' has just one buffer in its points-to set.
+        CHECK_EQ(1, instruction_buffers.size());
+        const LogicalBuffer* instruction_buffer = instruction_buffers[0];
+
+        const std::vector<const LogicalBuffer*> other_instruction_buffers =
+            points_to_analysis.GetPointsToSet(other_instruction).element(index);
+        // Do not insert a copy if both instructions point at the same buffer.
+        // This eliminates unnecessary copies of read-only tuple elements.
+        // If 'instruction_' and 'other_instruction' point to the same buffer,
+        // then that buffer is not updated on the path between the two
+        // instructions. Therefore, any other (possibly interference-causing)
+        // users of that buffer from 'other_instruction' will see the same data,
+        // irrespective of whether we insert a copy of this buffer at
+        // 'instruction_' or not.
+        if (other_instruction_buffers.size() == 1 &&
+            other_instruction_buffers[0]->id() == instruction_buffer->id()) {
+          return tensorflow::Status::OK();
+        }
+        // We can't say anything about the ambiguity of 'other_instruction' at
+        // this point, so we need to check interference between the single
+        // buffer in the points-to set of 'instruction_' and all buffers in
+        // 'other_instruction_buffers'.
+ for (auto& other_buffer : other_instruction_buffers) { + if (liveness->MayInterfere(*instruction_buffer, *other_buffer)) { + VLOG(2) << "Adding copy of buffer for instruction: " + << instruction_->name() + << " at index: " << tensorflow::str_util::Join(index, ",") + << " because of interference with buffer: " + << other_buffer->ToString(); + RecordIndex(index); + break; + } + } + return tensorflow::Status::OK(); + })); + return tensorflow::Status::OK(); +} + +// Recursively inserts copies of 'instruction' tuple element buffers at +// indices in 'indices_to_copy_', expanding tuples as needed. +// TODO(b/31159897) Remove superfluous Tuple->GTE->Tuple expressions. +HloInstruction* InstructionCopier::CopyTuple(HloInstruction* instruction, + ShapeIndex* index) { + std::vector element_copies; + const int64 num_tuple_elements = + ShapeUtil::TupleElementCount(instruction->shape()); + for (int64 i = 0; i < num_tuple_elements; ++i) { + HloInstruction* gte = instruction->parent()->AddInstruction( + HloInstruction::CreateGetTupleElement( + ShapeUtil::GetSubshape(instruction->shape(), {i}), instruction, i)); + HloInstruction* element_copy; + index->push_back(i); + if (ShapeUtil::IsTuple(gte->shape())) { + element_copy = CopyTuple(gte, index); + } else { + if (indices_to_copy_.element(*index)) { + element_copy = gte->parent()->AddInstruction( + HloInstruction::CreateUnary(gte->shape(), HloOpcode::kCopy, gte)); + } else { + element_copy = gte; + } + } + index->pop_back(); + element_copies.push_back(element_copy); + } + return instruction->parent()->AddInstruction( + HloInstruction::CreateTuple(element_copies)); +} + +// Inserts copies of 'instruction_' buffers at indices in 'indices_to_copy_'. +HloInstruction* InstructionCopier::Copy() { + ShapeIndex index; + HloInstruction* copy; + if (ShapeUtil::IsTuple(instruction_->shape())) { + copy = CopyTuple(instruction_, &index); + } else { + copy = instruction_->parent()->AddInstruction(HloInstruction::CreateUnary( + instruction_->shape(), HloOpcode::kCopy, instruction_)); + } + for (HloInstruction* user : copy_users_) { + VLOG(2) << "Adding copy between instruction: " << instruction_->name() + << " and user: " << user->name(); + instruction_->ReplaceUseWith(user, copy); + } + return copy; +} + +} // anonymous namespace + +StatusOr CopyInsertion::FindOrInsertCopy(HloInstruction* hlo) { + auto copy_it = inserted_copies_.find(hlo); + if (copy_it == inserted_copies_.end()) { + HloInstruction* copy = hlo->parent()->DeepCopyInstruction(hlo).ValueOrDie(); + inserted_copies_.insert({hlo, copy}); + return copy; + } else { + return copy_it->second; + } +} + +StatusOr CopyInsertion::Run(HloModule* module) { + bool changed = false; + VLOG(2) << "CopyInsertion for module " << module->name(); + + TF_ASSIGN_OR_RETURN( + std::unique_ptr liveness, + BufferLiveness::Run(module, MakeUnique(module))); + auto& points_to_analysis = liveness->points_to_analysis(); + XLA_VLOG_LINES(2, points_to_analysis.ToString()); + XLA_VLOG_LINES(2, module->ToString()); + + // Gather references to all while body computations in 'module'. + std::unordered_set while_body_computations; + // Gather references to all while instructions in 'module' by computation. 
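+  //
+  // As an illustrative sketch (hypothetical module): if the entry computation
+  // contains a single kWhile instruction `w` whose body is computation `B`,
+  // the gathering loop below leaves
+  //   while_body_computations == {B}
+  //   while_instructions      == {entry -> {w}}
+  // which is what the per-computation copy insertion further down consumes.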
+ std::unordered_map> + while_instructions; + for (auto& computation : module->computations()) { + for (auto& instruction : computation->instructions()) { + if (instruction->opcode() != HloOpcode::kWhile) { + continue; + } + while_body_computations.insert(instruction->while_body()); + auto it = while_instructions.find(computation.get()); + if (it == while_instructions.end()) { + while_instructions.insert( + {computation.get(), std::vector()}); + } + while_instructions[computation.get()].emplace_back(instruction.get()); + } + } + + for (auto& computation : module->computations()) { + VLOG(2) << "computation " << computation->name(); + + // Collect instruction buffer indices to copy in 'instructions_to_copy'. + std::vector instructions_to_copy; + + // Add copies of while 'init' operand instructions (if needed). + // TODO(b/33301720) Remove redundant while instruction copies. + auto it = while_instructions.find(computation.get()); + if (it != while_instructions.end()) { + for (auto& while_hlo : it->second) { + // Create InstructionCopier for init operand of while instruction. + HloInstruction* init_hlo = while_hlo->mutable_operand(0); + instructions_to_copy.push_back( + InstructionCopier(/*init_value=*/false, init_hlo, {while_hlo})); + InstructionCopier& init_copier = instructions_to_copy.back(); + // Record 'init' buffer indices which point-to a Constant or Parameter. + TF_RETURN_IF_ERROR(init_copier.RecordIndicesWhichPointToParamOrConstant( + liveness->points_to_analysis())); + // Record indices necessary to colocate while and init operand buffers. + TF_RETURN_IF_ERROR(init_copier.RecordIndicesToCopyForColocatingBuffers( + liveness.get(), while_hlo)); + } + } + + // Create InstructionCopier for computation root instruction. + instructions_to_copy.push_back(InstructionCopier( + /*init_value=*/false, computation->root_instruction(), {})); + InstructionCopier& root_copier = instructions_to_copy.back(); + + if (while_body_computations.count(computation.get()) > 0) { + // Record root indices to copy for while body sub-computations. + // We do not need to call RecordIndicesWhichPointToParamOrConstant for + // the while root instruction here, because any neccessary copies needed + // to avoid constant or parameters in the output are handled by while.init + // operand copy insertion above (which will share an allocation). + TF_RETURN_IF_ERROR(root_copier.RecordIndicesToCopyForColocatingBuffers( + liveness.get(), computation->parameter_instruction(0))); + } else { + // Record root indices to copy for general computations. + TF_RETURN_IF_ERROR(root_copier.RecordIndicesWhichPointToParamOrConstant( + liveness->points_to_analysis())); + } + + for (auto& to_copy : instructions_to_copy) { + if (to_copy.HasAllIndicesFalse()) { + continue; + } + changed = true; + + // Copy instruction at recorded buffer indices. + HloInstruction* copy = to_copy.Copy(); + if (to_copy.instruction() == computation->root_instruction()) { + computation->set_root_instruction(copy); + } + } + } + + VLOG(3) << "After copy insertion for module " << module->name(); + XLA_VLOG_LINES(3, module->ToString()); + + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/copy_insertion.h b/tensorflow/compiler/xla/service/copy_insertion.h new file mode 100644 index 0000000000..4ea393ba94 --- /dev/null +++ b/tensorflow/compiler/xla/service/copy_insertion.h @@ -0,0 +1,54 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COPY_INSERTION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_COPY_INSERTION_H_ + +#include "tensorflow/compiler/xla/service/buffer_liveness.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass.h" + +namespace xla { + +// HLO pass which inserts a copy of the root instruction (creating a new root) +// if the root is or points-to any constant or parameter instruction. +// If the root instruction is a Tuple, only tuple elements which point to +// constant or parameter instructions will be copied. +// Copy insertion is necessary because constant and parameter arrays have +// different lifetimes than computation results. +class CopyInsertion : public HloPass { + public: + CopyInsertion() : HloPass("copy-insertion") {} + ~CopyInsertion() override {} + + // Run the pass on the given module. Returns whether the module was changed + // (copies were inserted). + StatusOr Run(HloModule* module) override; + + protected: + // Returns a copy of `hlo`. Looks in inserted_copies_ first to avoid making + // duplicate copies. + StatusOr FindOrInsertCopy(HloInstruction* hlo); + + // A map containing all copies inserted during the copy insertion pass. The + // key is the copied instruction and the value is the copy. + std::unordered_map inserted_copies_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COPY_INSERTION_H_ diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc new file mode 100644 index 0000000000..e64da58dc7 --- /dev/null +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -0,0 +1,1153 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/copy_insertion.h"
+
+#include
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+#include "tensorflow/compiler/xla/test_helpers.h"
+
+namespace xla {
+namespace {
+
+class CopyInsertionTest : public HloTestBase {
+ protected:
+  void InsertCopies(HloModule* module) {
+    CopyInsertion copy_insertion;
+    EXPECT_IS_OK(copy_insertion.Run(module).status());
+
+    // Verify the points-to set of the root of the computation after copy
+    // insertion contains no constants or parameters.
+    auto points_to_analysis =
+        TuplePointsToAnalysis::Run(module).ConsumeValueOrDie();
+    const std::set<const LogicalBuffer*> maybe_live_out_buffers =
+        points_to_analysis
+            ->GetPointsToSet(module->entry_computation()->root_instruction())
+            .CreateFlattenedSet();
+    for (const LogicalBuffer* buffer : maybe_live_out_buffers) {
+      EXPECT_NE(buffer->instruction()->opcode(), HloOpcode::kConstant);
+      EXPECT_NE(buffer->instruction()->opcode(), HloOpcode::kParameter);
+    }
+  }
+
+  // OperandTree is a test helper class that simplifies the expression of
+  // an expected tree of operands (starting at some root instruction) in a
+  // unit test.
+  // Each HLO instruction is represented as a node in the OperandTree.
+  struct OperandTree {
+    // The expected opcode for this OperandTree node.
+    HloOpcode opcode;
+    // The set of operands expected for this OperandTree node.
+    std::vector<OperandTree> operands;
+    // If non-null, a pointer to the expected HloInstruction at this node.
+    const HloInstruction* instruction = nullptr;
+
+    // Returns a mutable reference to operand 'i' of this node.
+    OperandTree& op(int i) {
+      if (i >= operands.size()) {
+        operands.resize(i + 1);
+      }
+      return operands[i];
+    }
+
+    // Check that 'instruction' and its operands match expected values recorded
+    // in OperandTree.
+    void Check(const HloInstruction* instruction) {
+      EXPECT_EQ(opcode, instruction->opcode());
+      if (this->instruction != nullptr) {
+        EXPECT_EQ(this->instruction, instruction);
+      }
+      if (operands.empty()) {
+        return;
+      }
+      EXPECT_EQ(operands.size(), instruction->operand_count());
+      for (int i = 0; i < instruction->operand_count(); ++i) {
+        operands[i].Check(instruction->operand(i));
+      }
+    }
+  };
+};
+
+#define EXPECT_INST(A, E...) EXPECT_EQ(A, (std::set<HloInstruction*>{E}))
+
+TEST_F(CopyInsertionTest, SingleParameter) {
+  auto builder = HloComputation::Builder(TestName());
+  HloInstruction* x = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "x"));
+  HloInstruction* tuple =
+      builder.AddInstruction(HloInstruction::CreateTuple({x}));
+
+  EXPECT_INST(x->users(), tuple);
+
+  HloModule module(TestName());
+  module.AddEntryComputation(builder.Build());
+
+  HloInstruction* old_root = module.entry_computation()->root_instruction();
+  InsertCopies(&module);
+  HloInstruction* new_root = module.entry_computation()->root_instruction();
+
+  // Check path from 'new_root' to 'old_root'.
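+  // As a sketch, the OperandTree below encodes the expected post-pass graph
+  //   Tuple(Copy(GetTupleElement(Tuple(x), 0)))
+  // i.e. the new root re-tuples a copy of the element read out of the old
+  // root, so the parameter buffer is no longer aliased by the output.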
+ OperandTree op_tree; + op_tree.opcode = HloOpcode::kTuple; + + op_tree.op(0).opcode = HloOpcode::kCopy; + op_tree.op(0).op(0).opcode = HloOpcode::kGetTupleElement; + op_tree.op(0).op(0).op(0).opcode = HloOpcode::kTuple; + op_tree.op(0).op(0).op(0).instruction = old_root; + + op_tree.Check(new_root); +} + +TEST_F(CopyInsertionTest, SingleConstant) { + auto builder = HloComputation::Builder(TestName()); + HloInstruction* constant = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction* tuple = + builder.AddInstruction(HloInstruction::CreateTuple({constant})); + + EXPECT_INST(constant->users(), tuple); + + HloModule module(TestName()); + module.AddEntryComputation(builder.Build()); + + HloInstruction* old_root = module.entry_computation()->root_instruction(); + InsertCopies(&module); + HloInstruction* new_root = module.entry_computation()->root_instruction(); + + // Check path from 'new_root' to 'old_root'. + OperandTree op_tree; + op_tree.opcode = HloOpcode::kTuple; + + op_tree.op(0).opcode = HloOpcode::kCopy; + op_tree.op(0).op(0).opcode = HloOpcode::kGetTupleElement; + op_tree.op(0).op(0).op(0).opcode = HloOpcode::kTuple; + op_tree.op(0).op(0).op(0).instruction = old_root; + + op_tree.Check(new_root); +} + +TEST_F(CopyInsertionTest, MultipleConstantsAndParameters) { + // Create a computation with more than one constant and parameter. Only one of + // each constant/parameter is pointed to by the output tuple. Only these + // instructions should be copied. + auto builder = HloComputation::Builder(TestName()); + + HloInstruction* constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction* constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + + HloInstruction* x = builder.AddInstruction( + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "x")); + HloInstruction* y = builder.AddInstruction( + HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {}), "y")); + + HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, constant1, y)); + + builder.AddInstruction(HloInstruction::CreateTuple({constant2, x, add})); + + HloModule module(TestName()); + module.AddEntryComputation(builder.Build()); + + HloInstruction* old_root = module.entry_computation()->root_instruction(); + InsertCopies(&module); + HloInstruction* new_root = module.entry_computation()->root_instruction(); + + // "constant2" and parameter "x" are pointed to by the tuple and should be + // copied. + + // Check all paths from 'new_root' to 'old_root'. + OperandTree op_tree; + op_tree.opcode = HloOpcode::kTuple; + + op_tree.op(0).opcode = HloOpcode::kCopy; + op_tree.op(0).op(0).opcode = HloOpcode::kGetTupleElement; + op_tree.op(0).op(0).op(0).opcode = HloOpcode::kTuple; + op_tree.op(0).op(0).op(0).instruction = old_root; + + op_tree.op(1).opcode = HloOpcode::kCopy; + op_tree.op(1).op(0).opcode = HloOpcode::kGetTupleElement; + op_tree.op(1).op(0).op(0).opcode = HloOpcode::kTuple; + op_tree.op(1).op(0).op(0).instruction = old_root; + + op_tree.op(2).opcode = HloOpcode::kGetTupleElement; + op_tree.op(2).op(0).opcode = HloOpcode::kTuple; + op_tree.op(2).op(0).instruction = old_root; + + op_tree.Check(new_root); +} + +TEST_F(CopyInsertionTest, AmbiguousPointsToSet) { + // Create a computation using select which has an ambiguous points-to set for + // the computation result. 
Verify that copies are added properly. + auto builder = HloComputation::Builder(TestName()); + HloInstruction* constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction* constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction* constant3 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); + + HloInstruction* tuple1 = builder.AddInstruction( + HloInstruction::CreateTuple({constant1, constant2})); + HloInstruction* tuple2 = builder.AddInstruction( + HloInstruction::CreateTuple({constant3, constant2})); + + HloInstruction* pred = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + builder.AddInstruction(HloInstruction::CreateTernary( + tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); + + EXPECT_INST(constant1->users(), tuple1); + EXPECT_INST(constant2->users(), tuple1, tuple2); + EXPECT_INST(constant3->users(), tuple2); + + HloModule module(TestName()); + module.AddEntryComputation(builder.Build()); + + HloInstruction* old_root = module.entry_computation()->root_instruction(); + InsertCopies(&module); + HloInstruction* new_root = module.entry_computation()->root_instruction(); + + // Check all paths from 'new_root' to 'old_root'. + OperandTree op_tree; + op_tree.opcode = HloOpcode::kTuple; + + op_tree.op(0).opcode = HloOpcode::kCopy; + op_tree.op(0).op(0).opcode = HloOpcode::kGetTupleElement; + op_tree.op(0).op(0).op(0).opcode = HloOpcode::kSelect; + op_tree.op(0).op(0).op(0).instruction = old_root; + + op_tree.op(1).opcode = HloOpcode::kCopy; + op_tree.op(1).op(0).opcode = HloOpcode::kGetTupleElement; + op_tree.op(1).op(0).op(0).opcode = HloOpcode::kSelect; + op_tree.op(1).op(0).op(0).instruction = old_root; + + op_tree.Check(new_root); +} + +TEST_F(CopyInsertionTest, BitcastParameter) { + // The output of a bitcast is its operand (same buffer), so a bitcast + // parameter feeding the result must have a copy added. + auto builder = HloComputation::Builder(TestName()); + HloInstruction* x = builder.AddInstruction( + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {4}), "x")); + HloInstruction* bitcast = builder.AddInstruction(HloInstruction::CreateUnary( + ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, x)); + + HloModule module(TestName()); + module.AddEntryComputation(builder.Build()); + + EXPECT_INST(x->users(), bitcast); + + HloInstruction* old_root = module.entry_computation()->root_instruction(); + InsertCopies(&module); + HloInstruction* new_root = module.entry_computation()->root_instruction(); + + // Check path from 'new_root' to 'old_root'. + OperandTree op_tree; + op_tree.opcode = HloOpcode::kCopy; + op_tree.op(0).opcode = HloOpcode::kBitcast; + op_tree.op(0).instruction = old_root; + + op_tree.Check(new_root); +} + +TEST_F(CopyInsertionTest, BitcastConstant) { + // The output of a bitcast is its operand (same buffer), so a bitcast + // constant feeding the result must have a copy added. 
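+  // Sketch of the result expected by the checks below: the root becomes
+  //   Copy(Bitcast(constant))
+  // so the constant's buffer no longer aliases the computation's output.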
+ auto builder = HloComputation::Builder(TestName()); + HloInstruction* constant = + builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1.0, 42.0}))); + HloInstruction* bitcast = builder.AddInstruction(HloInstruction::CreateUnary( + ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, constant)); + + HloModule module(TestName()); + module.AddEntryComputation(builder.Build()); + + EXPECT_INST(constant->users(), bitcast); + + HloInstruction* old_root = module.entry_computation()->root_instruction(); + InsertCopies(&module); + HloInstruction* new_root = module.entry_computation()->root_instruction(); + + // Check path from 'new_root' to 'old_root'. + OperandTree op_tree; + op_tree.opcode = HloOpcode::kCopy; + op_tree.op(0).opcode = HloOpcode::kBitcast; + op_tree.op(0).instruction = old_root; + + op_tree.Check(new_root); +} + +TEST_F(CopyInsertionTest, BitcastTupleElementParameter) { + // Same as BitcastParameter, but the bitcast is wrapped in a tuple. + auto builder = HloComputation::Builder(TestName()); + HloInstruction* x = builder.AddInstruction( + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {4}), "x")); + HloInstruction* bitcast = builder.AddInstruction(HloInstruction::CreateUnary( + ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, x)); + builder.AddInstruction(HloInstruction::CreateTuple({bitcast})); + + HloModule module(TestName()); + module.AddEntryComputation(builder.Build()); + + EXPECT_EQ(1, x->user_count()); + EXPECT_EQ(*x->users().begin(), bitcast); + + HloInstruction* old_root = module.entry_computation()->root_instruction(); + InsertCopies(&module); + HloInstruction* new_root = module.entry_computation()->root_instruction(); + + // Check path from 'new_root' to 'old_root'. + OperandTree op_tree; + op_tree.opcode = HloOpcode::kTuple; + op_tree.op(0).opcode = HloOpcode::kCopy; + op_tree.op(0).op(0).opcode = HloOpcode::kGetTupleElement; + op_tree.op(0).op(0).op(0).opcode = HloOpcode::kTuple; + op_tree.op(0).op(0).op(0).instruction = old_root; + + op_tree.Check(new_root); +} + +TEST_F(CopyInsertionTest, NestedTupleParameter) { + // Construct a trivial computation where the root of the computation is a + // nested tuple-shaped parameter. The parameter should be deep copied and the + // copy should be the root of the computation. + auto builder = HloComputation::Builder(TestName()); + + // Param shape is: ((F32[], S32[1,2,3]), F32[42]) + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {}), + ShapeUtil::MakeShape(S32, {1, 2, 3})}), + ShapeUtil::MakeShape(F32, {42})}), + "param0")); + + HloModule module(TestName()); + module.AddEntryComputation(builder.Build()); + + EXPECT_EQ(HloOpcode::kParameter, + module.entry_computation()->root_instruction()->opcode()); + + HloInstruction* old_root = module.entry_computation()->root_instruction(); + InsertCopies(&module); + HloInstruction* new_root = module.entry_computation()->root_instruction(); + EXPECT_NE(old_root, new_root); + + // Check all paths from 'new_root' to 'old_root'. 
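+  // Sketch of the deep copy expected by the checks below, with p denoting the
+  // original parameter (the old root):
+  //   Tuple(Tuple(Copy(GTE(GTE(p, 0), 0)), Copy(GTE(GTE(p, 0), 1))),
+  //         Copy(GTE(p, 1)))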
+ OperandTree op_tree; + op_tree.opcode = HloOpcode::kTuple; + + op_tree.op(0).opcode = HloOpcode::kTuple; + op_tree.op(0).op(0).opcode = HloOpcode::kCopy; + op_tree.op(0).op(0).op(0).opcode = HloOpcode::kGetTupleElement; + op_tree.op(0).op(0).op(0).op(0).opcode = HloOpcode::kGetTupleElement; + op_tree.op(0).op(0).op(0).op(0).op(0).opcode = HloOpcode::kParameter; + op_tree.op(0).op(0).op(0).op(0).op(0).instruction = old_root; + + op_tree.op(0).opcode = HloOpcode::kTuple; + op_tree.op(0).op(1).opcode = HloOpcode::kCopy; + op_tree.op(0).op(1).op(0).opcode = HloOpcode::kGetTupleElement; + op_tree.op(0).op(1).op(0).op(0).opcode = HloOpcode::kGetTupleElement; + op_tree.op(0).op(1).op(0).op(0).op(0).opcode = HloOpcode::kParameter; + op_tree.op(0).op(1).op(0).op(0).op(0).instruction = old_root; + + op_tree.op(1).opcode = HloOpcode::kCopy; + op_tree.op(1).op(0).opcode = HloOpcode::kGetTupleElement; + op_tree.op(1).op(0).op(0).opcode = HloOpcode::kParameter; + op_tree.op(1).op(0).op(0).instruction = old_root; + + op_tree.Check(new_root); +} + +TEST_F(CopyInsertionTest, ElementOfNestedTupleParameter) { + // Construct a computation where the root of the computation is a tuple + // element of a nested tuple-shaped parameter. + auto builder = HloComputation::Builder(TestName()); + + // Param shape is: ((F32[], S32[1,2,3]), F32[42]) + auto param = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {}), + ShapeUtil::MakeShape(S32, {1, 2, 3})}), + ShapeUtil::MakeShape(F32, {42})}), + "param0")); + + // The return value of the computation is the zero-th elemnt of the nested + // tuple. This element is itself a tuple. + auto gte = builder.AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetSubshape(param->shape(), {0}), param, 0)); + + HloModule module(TestName()); + module.AddEntryComputation(builder.Build()); + + EXPECT_EQ(gte, module.entry_computation()->root_instruction()); + + HloInstruction* old_root = module.entry_computation()->root_instruction(); + InsertCopies(&module); + HloInstruction* new_root = module.entry_computation()->root_instruction(); + + // Check all paths from 'new_root' to 'old_root'. + OperandTree op_tree; + op_tree.opcode = HloOpcode::kTuple; + + op_tree.op(0).opcode = HloOpcode::kCopy; + op_tree.op(0).op(0).opcode = HloOpcode::kGetTupleElement; + op_tree.op(0).op(0).op(0).opcode = HloOpcode::kGetTupleElement; + op_tree.op(0).op(0).op(0).instruction = old_root; + + op_tree.op(1).opcode = HloOpcode::kCopy; + op_tree.op(1).op(0).opcode = HloOpcode::kGetTupleElement; + op_tree.op(1).op(0).op(0).opcode = HloOpcode::kGetTupleElement; + op_tree.op(1).op(0).op(0).instruction = old_root; + + op_tree.Check(new_root); +} + +TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) { + // Create a computation using select which has an ambiguous points-to set for + // the top-level buffer of the root of the computation. Verify that a shallow + // copy is added. 
+ auto builder = HloComputation::Builder(TestName()); + HloInstruction* constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction* constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + + HloInstruction* tuple1 = builder.AddInstruction( + HloInstruction::CreateTuple({constant1, constant2})); + HloInstruction* tuple2 = builder.AddInstruction( + HloInstruction::CreateTuple({constant2, constant1})); + + HloInstruction* pred = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction* select = builder.AddInstruction(HloInstruction::CreateTernary( + tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); + HloInstruction* gte = + builder.AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetSubshape(select->shape(), {0}), select, 0)); + + HloModule module(TestName()); + module.AddEntryComputation(builder.Build()); + + EXPECT_EQ(gte, module.entry_computation()->root_instruction()); + + HloInstruction* old_root = module.entry_computation()->root_instruction(); + InsertCopies(&module); + HloInstruction* new_root = module.entry_computation()->root_instruction(); + + // Check path from 'new_root' to 'old_root'. + OperandTree op_tree; + op_tree.opcode = HloOpcode::kCopy; + op_tree.op(0).opcode = HloOpcode::kGetTupleElement; + op_tree.op(0).instruction = old_root; + + op_tree.Check(new_root); +} + +class WhileCopyInsertionTest : public CopyInsertionTest { + protected: + WhileCopyInsertionTest() : module_(TestName()) {} + + // Builds a While condition computation which reads the induction variable + // from the tuple parameter, and returns a predicate indicating whether this + // value is less than the constant '10'. + // The parameter 'nested' specifies the loop state shape from which to + // read the induction variable. + std::unique_ptr BuildConditionComputation( + bool nested = false) { + auto builder = HloComputation::Builder(TestName() + ".Condition"); + auto limit_const = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(10))); + const Shape& loop_state_shape = + nested ? nested_loop_state_shape_ : loop_state_shape_; + auto loop_state = builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_state_shape, "loop_state")); + auto induction_variable = + builder.AddInstruction(HloInstruction::CreateGetTupleElement( + limit_const->shape(), loop_state, 0)); + builder.AddInstruction( + HloInstruction::CreateBinary(condition_result_shape_, HloOpcode::kLt, + induction_variable, limit_const)); + return builder.Build(); + } + + // Builds a While body computation with one output tuple element dependent on + // both input tuple elements. + // EX: + // Body({in0, in1}) + // out0 = Add(in0, 1) + // out1 = Add(BCast(in0), in1) + // Tuple(out0, out1) + std::unique_ptr BuildDependentBodyComputation() { + auto builder = HloComputation::Builder(TestName() + ".Body"); + // Create param instruction to access loop state. + auto loop_state = builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state")); + // Update the induction variable GTE(0). 
+ auto induction_variable = + builder.AddInstruction(HloInstruction::CreateGetTupleElement( + induction_variable_shape_, loop_state, 0)); + auto inc = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); + auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( + induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc)); + // Update data GTE(1). + auto data = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1)); + // Use 'induction_variable' in computation with no path to output tuple. + auto update = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape_, induction_variable, {8})); + auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( + data_shape_, HloOpcode::kAdd, data, update)); + // Create output Tuple. + builder.AddInstruction(HloInstruction::CreateTuple({add0, add1})); + return builder.Build(); + } + + // Builds a While body computation with read-only tuple element 0. + // both input tuple elements. + // EX: + // Body({in0, in1}) + // out0 = in0 + // out1 = Add(BCast(in0), in1) + // Tuple(out0, out1) + std::unique_ptr BuildDependentBodyOneReadOnlyComputation() { + auto builder = HloComputation::Builder(TestName() + ".Body"); + // Create param instruction to access loop state. + auto loop_state = builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state")); + // Update the induction variable GTE(0). + auto induction_variable = + builder.AddInstruction(HloInstruction::CreateGetTupleElement( + induction_variable_shape_, loop_state, 0)); + // Update data GTE(1). + auto data = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1)); + // Use 'induction_variable' in computation with no path to output tuple. + auto update = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape_, induction_variable, {8})); + auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( + data_shape_, HloOpcode::kAdd, data, update)); + // Create output Tuple. + builder.AddInstruction( + HloInstruction::CreateTuple({induction_variable, add1})); + return builder.Build(); + } + + // Builds a While body computation with independent outputs. + // EX: + // Body({in0, in1}) + // out0 = Add(in0, 1) + // out1 = Add(in1, {1, 1, 1, 1, 1, 1, 1, 1}) + // Tuple(out0, out1) + std::unique_ptr BuildIndependentBodyComputation() { + auto builder = HloComputation::Builder(TestName() + ".Body"); + // Create param instruction to access loop state. + auto loop_state = builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state")); + // Update the induction variable GTE(0). + auto induction_variable = + builder.AddInstruction(HloInstruction::CreateGetTupleElement( + induction_variable_shape_, loop_state, 0)); + auto inc = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); + // add0 = Add(in0, 1) + auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( + induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc)); + // Update data GTE(1). 
+ auto data = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1)); + auto update = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1( + {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); + // add0 = Add(in1, {1, 1, 1, 1, 1, 1, 1, 1}) + auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( + data_shape_, HloOpcode::kAdd, data, update)); + // Create output Tuple. + builder.AddInstruction(HloInstruction::CreateTuple({add0, add1})); + return builder.Build(); + } + + // Builds a While body computation with the following nested tuple + // sub-computation: + // | + // GTE(loop_state, 1) + // / \ + // GTE(GTE(loop_state, 1), 0) GTE(GTE(loop_state, 1), 1) + // | | + // Add Reverse + // | | + std::unique_ptr BuildNestedBodyComputation() { + auto builder = HloComputation::Builder(TestName() + ".Body"); + // Create param instruction to access loop state. + auto loop_state = builder.AddInstruction(HloInstruction::CreateParameter( + 0, nested_loop_state_shape_, "loop_state")); + // Update GTE(0). + auto gte0 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( + induction_variable_shape_, loop_state, 0)); + auto inc = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); + auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( + gte0->shape(), HloOpcode::kAdd, gte0, inc)); + + // GTE(loop_state, 1) + auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( + nested_tuple_shape_, loop_state, 1)); + // GTE(GTE(loop_state, 1), 0) -> Add + auto gte10 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape_, gte1, 0)); + auto update10 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1( + {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); + auto add10 = builder.AddInstruction(HloInstruction::CreateBinary( + data_shape_, HloOpcode::kAdd, gte10, update10)); + + // GTE(GTE(loop_state, 1), 1) -> Reverse + auto gte11 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape_, gte1, 1)); + auto rev11 = builder.AddInstruction( + HloInstruction::CreateReverse(data_shape_, gte11, {0})); + + // Create output Tuple. + auto inner_tuple = + builder.AddInstruction(HloInstruction::CreateTuple({add10, rev11})); + builder.AddInstruction(HloInstruction::CreateTuple({add0, inner_tuple})); + return builder.Build(); + } + + // Builds a While instruction using 'condition' and 'body' sub-computations. + // Init operand is initialized to zeros of appropriate shape. 
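+  //
+  // Roughly, the entry computation built here looks like (sketch):
+  //   init = Tuple(Constant(0), Constant({0, ..., 0}))        // flat case
+  //   init = Tuple(Constant(0), Tuple(data_init, data_init))  // nested case
+  //   ROOT = While(condition, body, init)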
+ void BuildWhileInstruction(HloComputation* condition, HloComputation* body, + bool nested = false) { + auto builder = HloComputation::Builder(TestName() + ".While"); + auto induction_var_init = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + + auto data_init = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1( + {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}))); + + if (nested) { + auto inner_init = builder.AddInstruction( + HloInstruction::CreateTuple({data_init, data_init})); + auto loop_state_init = builder.AddInstruction( + HloInstruction::CreateTuple({induction_var_init, inner_init})); + builder.AddInstruction(HloInstruction::CreateWhile( + loop_state_shape_, condition, body, loop_state_init)); + module_.AddEntryComputation(builder.Build()); + return; + } + + auto loop_state_init = builder.AddInstruction( + HloInstruction::CreateTuple({induction_var_init, data_init})); + builder.AddInstruction(HloInstruction::CreateWhile( + loop_state_shape_, condition, body, loop_state_init)); + module_.AddEntryComputation(builder.Build()); + } + + HloInstruction* BuildWhileInstruction_InitPointsToConstant() { + auto builder = HloComputation::Builder(TestName() + ".While"); + auto data_init = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1( + {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}))); + return BuildWhileInstructionWithCustomInit(loop_state_shape_, data_init, + &builder); + } + + HloInstruction* BuildWhileInstruction_InitPointsToParameter() { + auto builder = HloComputation::Builder(TestName() + ".While"); + auto data_init = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape_, "data_init")); + return BuildWhileInstructionWithCustomInit(loop_state_shape_, data_init, + &builder); + } + + HloInstruction* BuildWhileInstruction_InitPointsToAmbiguous() { + auto builder = HloComputation::Builder(TestName() + ".While"); + + auto one = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + auto v1 = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape_, one, {1})); + auto zero = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + auto v2 = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape_, zero, {1})); + + auto tuple1 = builder.AddInstruction(HloInstruction::CreateTuple({v1, v2})); + auto tuple2 = builder.AddInstruction(HloInstruction::CreateTuple({v2, v1})); + + auto pred = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + auto data_init = builder.AddInstruction(HloInstruction::CreateTernary( + nested_tuple_shape_, HloOpcode::kSelect, pred, tuple1, tuple2)); + + return BuildWhileInstructionWithCustomInit(nested_loop_state_shape_, + data_init, &builder); + } + + HloInstruction* BuildWhileInstruction_InitPointsToNonDistinct() { + auto builder = HloComputation::Builder(TestName() + ".While"); + + auto one = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + auto one_vec = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape_, one, {1})); + auto data_init = + builder.AddInstruction(HloInstruction::CreateTuple({one_vec, one_vec})); + + return BuildWhileInstructionWithCustomInit(nested_loop_state_shape_, + data_init, &builder); + } + + HloInstruction* BuildWhileInstruction_InitPointsToInterfering() { + auto builder = HloComputation::Builder(TestName() + ".While"); + auto one = builder.AddInstruction( 
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + auto data_init = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape_, one, {1})); + auto one_vec = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1( + {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); + // Take a reference to 'data_init' to make it interfere with while result. + builder.AddInstruction(HloInstruction::CreateBinary( + data_shape_, HloOpcode::kAdd, data_init, one_vec)); + + return BuildWhileInstructionWithCustomInit(loop_state_shape_, data_init, + &builder); + } + + HloInstruction* BuildWhileInstructionWithCustomInit( + const Shape& loop_state_shape, HloInstruction* data_init, + HloComputation::Builder* builder) { + auto induction_var_init = builder->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + auto condition = + module_.AddEmbeddedComputation(BuildConditionComputation()); + auto body = + module_.AddEmbeddedComputation(BuildIndependentBodyComputation()); + auto loop_state_init = builder->AddInstruction( + HloInstruction::CreateTuple({induction_var_init, data_init})); + auto while_hlo = builder->AddInstruction(HloInstruction::CreateWhile( + loop_state_shape, condition, body, loop_state_init)); + module_.AddEntryComputation(builder->Build()); + return while_hlo; + } + + HloModule module_; + Shape induction_variable_shape_ = ShapeUtil::MakeShape(S32, {}); + Shape data_shape_ = ShapeUtil::MakeShape(F32, {8}); + Shape loop_state_shape_ = + ShapeUtil::MakeTupleShape({induction_variable_shape_, data_shape_}); + Shape nested_tuple_shape_ = + ShapeUtil::MakeTupleShape({data_shape_, data_shape_}); + Shape nested_loop_state_shape_ = ShapeUtil::MakeTupleShape( + {induction_variable_shape_, nested_tuple_shape_}); + Shape condition_result_shape_ = ShapeUtil::MakeShape(PRED, {}); +}; + +// Tests while body computation with independent tuple elements: +// +// While.Body({in0, in1}) +// out0 = Add(in0, 1) +// out1 = Add(in1, {1, 1, 1, 1, 1, 1, 1, 1}) +// Tuple(out0, out1) +// +// CopyInsertion pass should not generate any copies. +// +TEST_F(WhileCopyInsertionTest, IndependentTupleElements) { + auto condition = module_.AddEmbeddedComputation(BuildConditionComputation()); + auto body = module_.AddEmbeddedComputation(BuildIndependentBodyComputation()); + BuildWhileInstruction(condition, body); + + HloInstruction* old_root = body->root_instruction(); + InsertCopies(&module_); + HloInstruction* new_root = body->root_instruction(); + + // No copies should be inserted so root should not be updated. + CHECK_EQ(old_root, new_root); +} + +// Tests while body computation with dependent tuple elements: +// +// While.Body({in0, in1}) +// out0 = Add(in0, 1) +// out1 = Add(BCast(in0), in1) +// Tuple(out0, out1) +// +// CopyInsertion pass should generate: +// +// Tuple // old root +// / \ +// GTE(0) GTE(1) +// | | +// Copy | +// \ / +// Tuple // new root +// +TEST_F(WhileCopyInsertionTest, DependentTupleElements) { + auto condition = module_.AddEmbeddedComputation(BuildConditionComputation()); + auto body = module_.AddEmbeddedComputation(BuildDependentBodyComputation()); + BuildWhileInstruction(condition, body); + + HloInstruction* old_root = body->root_instruction(); + InsertCopies(&module_); + HloInstruction* new_root = body->root_instruction(); + + // Check all paths from 'new_root' to 'old_root'. 
+ OperandTree op_tree; + op_tree.opcode = HloOpcode::kTuple; + + op_tree.op(0).opcode = HloOpcode::kCopy; + op_tree.op(0).op(0).opcode = HloOpcode::kGetTupleElement; + op_tree.op(0).op(0).op(0).opcode = HloOpcode::kTuple; + op_tree.op(0).op(0).op(0).instruction = old_root; + + op_tree.op(1).opcode = HloOpcode::kGetTupleElement; + op_tree.op(1).op(0).opcode = HloOpcode::kTuple; + op_tree.op(1).op(0).instruction = old_root; + + op_tree.Check(new_root); +} + +// Tests while body computation with read-only tuple element 0: +// +// PARAMETER +// / \ +// GTE(0) GTE(1) +// | \ | +// | BCAST | +// | \ | +// | ADD +// | | +// \ / +// TUPLE (root) +// +// CopyInsertion pass should not generate any copies. +// +TEST_F(WhileCopyInsertionTest, DependentTupleElements_OneReadOnly) { + auto condition = module_.AddEmbeddedComputation(BuildConditionComputation()); + auto body = module_.AddEmbeddedComputation( + BuildDependentBodyOneReadOnlyComputation()); + BuildWhileInstruction(condition, body); + + HloInstruction* old_root = body->root_instruction(); + InsertCopies(&module_); + HloInstruction* new_root = body->root_instruction(); + + // No copies should be inserted so root should not be updated. + CHECK_EQ(old_root, new_root); +} + +// Tests while body computation with nested tuple elements: +// +// | +// GTE(loop_state, 1) +// / \ +// GTE(GTE(loop_state, 1), 0) GTE(GTE(loop_state, 1), 1) +// | | +// Add Reverse +// | | +// +// CopyInsertion pass should generate: +// +// Tuple // old root +// / \ +// / \ +// GTE(0) GTE(1) +// | / \ +// | / \ +// | GTE(0) GTE(1) +// | | | +// | | Copy +// | | | +// \ | / +// \ Tuple // "inner" tuple. +// \ / +// \ / +// Tuple // new root +// +TEST_F(WhileCopyInsertionTest, NestedTupleElements) { + auto condition = + module_.AddEmbeddedComputation(BuildConditionComputation(true)); + auto body = module_.AddEmbeddedComputation(BuildNestedBodyComputation()); + BuildWhileInstruction(condition, body, true); + + HloInstruction* old_root = body->root_instruction(); + InsertCopies(&module_); + HloInstruction* new_root = body->root_instruction(); + + // Check all paths from 'new_root' to 'old_root'. + OperandTree op_tree; + op_tree.opcode = HloOpcode::kTuple; + + op_tree.op(0).opcode = HloOpcode::kGetTupleElement; + op_tree.op(0).op(0).opcode = HloOpcode::kTuple; + op_tree.op(0).op(0).instruction = old_root; + + op_tree.op(1).opcode = HloOpcode::kTuple; + + op_tree.op(1).op(0).opcode = HloOpcode::kGetTupleElement; + op_tree.op(1).op(0).op(0).opcode = HloOpcode::kGetTupleElement; + op_tree.op(1).op(0).op(0).op(0).opcode = HloOpcode::kTuple; + op_tree.op(1).op(0).op(0).op(0).instruction = old_root; + + op_tree.op(1).op(1).opcode = HloOpcode::kCopy; + op_tree.op(1).op(1).op(0).opcode = HloOpcode::kGetTupleElement; + op_tree.op(1).op(1).op(0).op(0).opcode = HloOpcode::kGetTupleElement; + op_tree.op(1).op(1).op(0).op(0).op(0).opcode = HloOpcode::kTuple; + op_tree.op(1).op(1).op(0).op(0).op(0).instruction = old_root; + + op_tree.Check(new_root); +} + +// Tests while init instruction which points-to a constant. 
+// +// init = Tuple(Constant(S32, {}), Constant(F32, {8})) +// +// CopyInsertion pass should generate: +// +// Tuple // old init +// / \ +// GTE(0) GTE(1) +// | | +// Copy Copy +// \ / +// Tuple // new init +// +TEST_F(WhileCopyInsertionTest, InitPointsToConstant) { + auto while_hlo = BuildWhileInstruction_InitPointsToConstant(); + auto old_init = while_hlo->operand(0); + InsertCopies(&module_); + auto new_init = while_hlo->operand(0); + + // Check all paths from 'new_init' to 'old_init'. + OperandTree op_tree; + op_tree.opcode = HloOpcode::kTuple; + + op_tree.op(0).opcode = HloOpcode::kCopy; + op_tree.op(0).op(0).opcode = HloOpcode::kGetTupleElement; + op_tree.op(0).op(0).op(0).opcode = HloOpcode::kTuple; + op_tree.op(0).op(0).op(0).instruction = old_init; + + op_tree.op(1).opcode = HloOpcode::kCopy; + op_tree.op(1).op(0).opcode = HloOpcode::kGetTupleElement; + op_tree.op(1).op(0).op(0).opcode = HloOpcode::kTuple; + op_tree.op(1).op(0).op(0).instruction = old_init; + + op_tree.Check(new_init); +} + +// Tests while init instruction which points-to a parameter. +// +// init = Tuple(Constant(S32, {}), Parameter(F32, {8})) +// +// CopyInsertion pass should generate: +// +// Tuple // old init +// / \ +// GTE(0) GTE(1) +// | | +// Copy Copy +// \ / +// Tuple // new init +// +TEST_F(WhileCopyInsertionTest, InitPointsToParameter) { + auto while_hlo = BuildWhileInstruction_InitPointsToParameter(); + auto old_init = while_hlo->operand(0); + InsertCopies(&module_); + auto new_init = while_hlo->operand(0); + + // Check all paths from 'new_init' to 'old_init'. + OperandTree op_tree; + op_tree.opcode = HloOpcode::kTuple; + + op_tree.op(0).opcode = HloOpcode::kCopy; + op_tree.op(0).op(0).opcode = HloOpcode::kGetTupleElement; + op_tree.op(0).op(0).op(0).opcode = HloOpcode::kTuple; + op_tree.op(0).op(0).op(0).instruction = old_init; + + op_tree.op(1).opcode = HloOpcode::kCopy; + op_tree.op(1).op(0).opcode = HloOpcode::kGetTupleElement; + op_tree.op(1).op(0).op(0).opcode = HloOpcode::kTuple; + op_tree.op(1).op(0).op(0).instruction = old_init; + + op_tree.Check(new_init); +} + +// Tests while init instruction which has an ambiguous points-to set. +// +// select = Select(pred, tuple1, tuple2) +// init = Tuple(Constant(S32, {}), Parameter(F32, {8})) +// +// CopyInsertion pass should generate: +// +// Tuple // old init +// / \ +// / \ +// GTE(0) GTE(1) +// | / \ +// | / \ +// | GTE(0) GTE(1) +// | | | +// Copy Copy Copy +// | | | +// \ | / +// \ Tuple +// \ / +// \ / +// Tuple // new init +// +TEST_F(WhileCopyInsertionTest, InitPointsToAmbiguous) { + auto while_hlo = BuildWhileInstruction_InitPointsToAmbiguous(); + auto old_init = while_hlo->operand(0); + InsertCopies(&module_); + auto new_init = while_hlo->operand(0); + + // Check all paths from 'new_init' to 'old_init'. 
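+ // Element 1 of the init tuple is the kSelect built above, so the pass is
+ // expected to deep-copy it element-wise. Written out (indices illustrative):
+ //   new_init = Tuple(Copy(GTE(old_init, 0)),
+ //                    Tuple(Copy(GTE(GTE(old_init, 1), 0)),
+ //                          Copy(GTE(GTE(old_init, 1), 1))))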
+ OperandTree op_tree; + op_tree.opcode = HloOpcode::kTuple; + + op_tree.op(0).opcode = HloOpcode::kCopy; + op_tree.op(0).op(0).opcode = HloOpcode::kGetTupleElement; + op_tree.op(0).op(0).op(0).opcode = HloOpcode::kTuple; + op_tree.op(0).op(0).op(0).instruction = old_init; + + op_tree.op(1).opcode = HloOpcode::kTuple; + + op_tree.op(1).op(0).opcode = HloOpcode::kCopy; + op_tree.op(1).op(0).op(0).opcode = HloOpcode::kGetTupleElement; + op_tree.op(1).op(0).op(0).op(0).opcode = HloOpcode::kGetTupleElement; + op_tree.op(1).op(0).op(0).op(0).op(0).opcode = HloOpcode::kTuple; + op_tree.op(1).op(0).op(0).op(0).op(0).instruction = old_init; + + op_tree.op(1).op(1).opcode = HloOpcode::kCopy; + op_tree.op(1).op(1).op(0).opcode = HloOpcode::kGetTupleElement; + op_tree.op(1).op(1).op(0).op(0).opcode = HloOpcode::kGetTupleElement; + op_tree.op(1).op(1).op(0).op(0).op(0).opcode = HloOpcode::kTuple; + op_tree.op(1).op(1).op(0).op(0).op(0).instruction = old_init; + + op_tree.Check(new_init); +} + +// Tests while init instruction which has a non-distinct points-to set. +// +// init = Tuple(Constant(S32, {}), Tuple({vec_one, vec_one})) +// +// CopyInsertion pass should generate: +// +// Tuple // old init +// / \ +// / \ +// GTE(0) GTE(1) +// | / \ +// | / \ +// | GTE(0) GTE(1) +// | | | +// Copy Copy Copy +// | | | +// \ | / +// \ Tuple +// \ / +// \ / +// Tuple // new init +// +TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinct) { + auto while_hlo = BuildWhileInstruction_InitPointsToNonDistinct(); + auto old_init = while_hlo->operand(0); + InsertCopies(&module_); + auto new_init = while_hlo->operand(0); + + // Check all paths from 'new_init' to 'old_init'. + OperandTree op_tree; + op_tree.opcode = HloOpcode::kTuple; + + op_tree.op(0).opcode = HloOpcode::kCopy; + op_tree.op(0).op(0).opcode = HloOpcode::kGetTupleElement; + op_tree.op(0).op(0).op(0).opcode = HloOpcode::kTuple; + op_tree.op(0).op(0).op(0).instruction = old_init; + + op_tree.op(1).opcode = HloOpcode::kTuple; + + op_tree.op(1).op(0).opcode = HloOpcode::kCopy; + op_tree.op(1).op(0).op(0).opcode = HloOpcode::kGetTupleElement; + op_tree.op(1).op(0).op(0).op(0).opcode = HloOpcode::kGetTupleElement; + op_tree.op(1).op(0).op(0).op(0).op(0).opcode = HloOpcode::kTuple; + op_tree.op(1).op(0).op(0).op(0).op(0).instruction = old_init; + + op_tree.op(1).op(1).opcode = HloOpcode::kCopy; + op_tree.op(1).op(1).op(0).opcode = HloOpcode::kGetTupleElement; + op_tree.op(1).op(1).op(0).op(0).opcode = HloOpcode::kGetTupleElement; + op_tree.op(1).op(1).op(0).op(0).op(0).opcode = HloOpcode::kTuple; + op_tree.op(1).op(1).op(0).op(0).op(0).instruction = old_init; + + op_tree.Check(new_init); +} + +// Tests while init instruction buffer which interfers with while result buffer. +// +// init_data = Broadcast(...) +// add_unrelated = Add(init_data) // takes a reference to cause interference +// init = Tuple(Constant(S32, {}), init_data)) +// +// CopyInsertion pass should generate: +// +// Tuple // old init +// / \ +// GTE(0) GTE(1) +// | | +// Copy Copy +// \ / +// Tuple // new init +// +TEST_F(WhileCopyInsertionTest, InitPointsToInterfering) { + auto while_hlo = BuildWhileInstruction_InitPointsToInterfering(); + auto old_init = while_hlo->operand(0); + InsertCopies(&module_); + auto new_init = while_hlo->operand(0); + + // Check all paths from 'new_init' to 'old_init'. 
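+ // Spelled out, the structure checked below is
+ //   new_init = Tuple(Copy(GTE(old_init, 0)), Copy(GTE(old_init, 1)))
+ // Both leaves are copied; per the test description above, the unrelated Add
+ // keeps the init data buffer live so it would otherwise interfere with the
+ // while result buffer.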
+ OperandTree op_tree; + op_tree.opcode = HloOpcode::kTuple; + + op_tree.op(0).opcode = HloOpcode::kCopy; + op_tree.op(0).op(0).opcode = HloOpcode::kGetTupleElement; + op_tree.op(0).op(0).op(0).opcode = HloOpcode::kTuple; + op_tree.op(0).op(0).op(0).instruction = old_init; + + op_tree.op(1).opcode = HloOpcode::kCopy; + op_tree.op(1).op(0).opcode = HloOpcode::kGetTupleElement; + op_tree.op(1).op(0).op(0).opcode = HloOpcode::kTuple; + op_tree.op(1).op(0).op(0).instruction = old_init; + + op_tree.Check(new_init); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD new file mode 100644 index 0000000000..8af54b11bb --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -0,0 +1,529 @@ +# Description: +# LLVM-based CPU backend for XLA. + +licenses(["notice"]) # Apache 2.0 + +package( + default_visibility = [":friends"], + features = [ + "-layering_check", + "-parse_headers", + ], +) + +package_group( + name = "friends", + includes = [ + "//tensorflow/compiler/xla:friends", + ], +) + +load(":build_defs.bzl", "runtime_copts") + +# Filegroup used to collect source files for dependency checking. +filegroup( + name = "c_srcs", + data = glob([ + "**/*.cc", + "**/*.h", + ]), +) + +cc_library( + name = "cpu_compiler", + srcs = ["cpu_compiler.cc"], + hdrs = ["cpu_compiler.h"], + deps = [ + ":compiler_functor", + ":conv_canonicalization", + ":cpu_executable", + ":cpu_instruction_fusion", + ":cpu_parallelization_preparation", + ":disassembler", + ":ir_emission_utils", + ":ir_emitter", + ":layout_assignment", + ":parallel_cpu_executable", + ":simple_orc_jit", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/port:initialize", + "//tensorflow/compiler/xla/service:algebraic_simplifier", + "//tensorflow/compiler/xla/service:buffer_assignment", + "//tensorflow/compiler/xla/service:buffer_liveness", + "//tensorflow/compiler/xla/service:compiler", + "//tensorflow/compiler/xla/service:copy_insertion", + "//tensorflow/compiler/xla/service:executable", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_cse", + "//tensorflow/compiler/xla/service:hlo_dce", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/compiler/xla/service:hlo_pass_pipeline", + "//tensorflow/compiler/xla/service:hlo_subcomputation_unification", + "//tensorflow/compiler/xla/service:inliner", + "//tensorflow/compiler/xla/service:reshape_mover", + "//tensorflow/compiler/xla/service:transpose_folding", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", # fixdeps: keep + "//tensorflow/core:lib", # fixdeps: keep + "//tensorflow/core:stream_executor_no_cuda", + "@llvm//:aarch64_code_gen", # fixdeps: keep + "@llvm//:aarch64_disassembler", # fixdeps: keep + "@llvm//:arm_code_gen", # fixdeps: keep + "@llvm//:arm_disassembler", # fixdeps: keep + "@llvm//:core", + "@llvm//:mc", # fixdeps: keep + "@llvm//:object", + "@llvm//:powerpc_code_gen", # fixdeps: keep + "@llvm//:powerpc_disassembler", # fixdeps: keep + "@llvm//:support", + "@llvm//:target", # fixdeps: keep + "@llvm//:x86_code_gen", # fixdeps: keep + "@llvm//:x86_disassembler", # fixdeps: keep + ], + alwayslink = 
True, # Contains compiler registration +) + +cc_library( + name = "simple_orc_jit", + srcs = ["simple_orc_jit.cc"], + hdrs = ["simple_orc_jit.h"], + deps = [ + ":compiler_functor", + ":cpu_runtime", + ":cpu_runtime_avx", + ":cpu_runtime_sse4_1", + ":disassembler", + ":runtime_conv2d", + ":runtime_matmul", + ":runtime_single_threaded_conv2d", + ":runtime_single_threaded_matmul", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/legacy_flags:llvm_backend_flags", + "//tensorflow/core:lib", + "@llvm//:core", + "@llvm//:mc", # fixdeps: keep + "@llvm//:orc_jit", + "@llvm//:support", + "@llvm//:target", # fixdeps: keep + ], +) + +cc_library( + name = "cpu_executable", + srcs = ["cpu_executable.cc"], + hdrs = ["cpu_executable.h"], + deps = [ + ":simple_orc_jit", + "//tensorflow/compiler/xla:shape_tree", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:buffer_assignment", + "//tensorflow/compiler/xla/service:computation_layout", + "//tensorflow/compiler/xla/service:device_memory_allocator", + "//tensorflow/compiler/xla/service:executable", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_execution_profile", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/service:logical_buffer", + "//tensorflow/compiler/xla/service:shaped_buffer", + "//tensorflow/compiler/xla/service:tuple_points_to_analysis", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + "@llvm//:orc_jit", + ], +) + +cc_library( + name = "parallel_cpu_executable", + srcs = ["parallel_cpu_executable.cc"], + hdrs = ["parallel_cpu_executable.h"], + deps = [ + ":cpu_runtime", + ":simple_orc_jit", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:buffer_assignment", + "//tensorflow/compiler/xla/service:device_memory_allocator", + "//tensorflow/compiler/xla/service:executable", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_execution_profile", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/service:logical_buffer", + "//tensorflow/compiler/xla/service:shaped_buffer", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + "@llvm//:orc_jit", + ], +) + +cc_library( + name = "ir_emitter", + srcs = ["ir_emitter.cc"], + hdrs = ["ir_emitter.h"], + deps = [ + ":cpu_runtime", + ":dot_op_emitter", + ":elemental_ir_emitter", + ":ir_emission_utils", + ":simple_orc_jit", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:window_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags", + "//tensorflow/compiler/xla/service:buffer_assignment", + "//tensorflow/compiler/xla/service:elemental_ir_emitter", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_module_config", + 
"//tensorflow/compiler/xla/service:name_uniquer", + "//tensorflow/compiler/xla/service/llvm_ir:alias_analysis", + "//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter", + "//tensorflow/compiler/xla/service/llvm_ir:ir_array", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", + "//tensorflow/compiler/xla/service/llvm_ir:ops", + "//tensorflow/core:lib", + "@llvm//:core", + "@llvm//:support", + ], +) + +cc_library( + name = "dot_op_emitter", + srcs = ["dot_op_emitter.cc"], + hdrs = ["dot_op_emitter.h"], + deps = [ + ":cpu_runtime", + ":ir_emission_utils", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service/llvm_ir:ir_array", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/core:lib", + "@llvm//:core", + ], +) + +cc_binary( + name = "sample_harness", + srcs = ["sample_harness.cc"], + deps = [ + "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "disassembler", + srcs = ["disassembler.cc"], + hdrs = ["disassembler.h"], + deps = [ + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + "@llvm//:mc", + "@llvm//:mc_disassembler", + "@llvm//:object", + "@llvm//:powerpc_disassembler", # fixdeps: keep + "@llvm//:support", + "@llvm//:target", + "@llvm//:x86_disassembler", # fixdeps: keep + ], +) + +cc_library( + name = "compiler_functor", + srcs = ["compiler_functor.cc"], + hdrs = ["compiler_functor.h"], + deps = [ + ":cpu_runtime", + ":cpu_runtime_avx", + ":cpu_runtime_sse4_1", + ":disassembler", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/legacy_flags:compiler_functor_flags", + "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/core:lib", + "@llvm//:analysis", + "@llvm//:core", + "@llvm//:ipo", + "@llvm//:mc", + "@llvm//:object", + "@llvm//:support", + "@llvm//:target", + ], +) + +cc_library( + name = "cpu_runtime_sse4_1", + srcs = ["cpu_runtime_sse4_1.cc"], + hdrs = ["cpu_runtime_sse4_1.h"], + copts = ["-DEIGEN_AVOID_STL_ARRAY"], + deps = [ + "//tensorflow/core:lib", + "//third_party/eigen3", + ], + alwayslink = True, +) + +cc_library( + name = "cpu_runtime_avx", + srcs = ["cpu_runtime_avx.cc"], + hdrs = ["cpu_runtime_avx.h"], + copts = ["-DEIGEN_AVOID_STL_ARRAY"], + deps = [ + "//tensorflow/core:lib", + "//third_party/eigen3", + ], + alwayslink = True, +) + +cc_library( + name = "cpu_runtime", 
+ srcs = [ + "cpu_runtime.cc", + "infeed_manager.cc", + ], + hdrs = [ + "cpu_runtime.h", + "infeed_manager.h", + ], + copts = runtime_copts(), + deps = [ + "//tensorflow/compiler/xla:types", + ], +) + +cc_library( + name = "runtime_conv2d", + srcs = [ + "runtime_conv2d.cc", + "runtime_conv2d_impl.h", + ], + hdrs = ["runtime_conv2d.h"], + copts = runtime_copts(), + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/core:framework_lite", + "//tensorflow/core/kernels:eigen_helpers", + "//third_party/eigen3", + ], +) + +cc_library( + name = "runtime_matmul", + srcs = ["runtime_matmul.cc"], + hdrs = ["runtime_matmul.h"], + copts = runtime_copts(), + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/core:framework_lite", + "//third_party/eigen3", + ], +) + +cc_library( + name = "runtime_single_threaded_conv2d", + srcs = [ + "runtime_conv2d_impl.h", + "runtime_single_threaded_conv2d.cc", + ], + hdrs = ["runtime_single_threaded_conv2d.h"], + copts = runtime_copts(), + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:framework_lite", + "//tensorflow/core/kernels:eigen_helpers", + "//third_party/eigen3", + ], +) + +cc_library( + name = "runtime_single_threaded_matmul", + srcs = ["runtime_single_threaded_matmul.cc"], + hdrs = ["runtime_single_threaded_matmul.h"], + copts = runtime_copts(), + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:framework_lite", + "//third_party/eigen3", + ], +) + +cc_test( + name = "cpu_runtime_test", + srcs = ["cpu_runtime_test.cc"], + deps = [ + ":cpu_runtime", + ":runtime_matmul", + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//third_party/eigen3", + ], +) + +cc_test( + name = "infeed_manager_test", + srcs = ["infeed_manager_test.cc"], + deps = [ + ":cpu_runtime", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "cpu_instruction_fusion", + srcs = ["cpu_instruction_fusion.cc"], + hdrs = ["cpu_instruction_fusion.h"], + deps = [ + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:instruction_fusion", + ], +) + +cc_library( + name = "cpu_parallelization_preparation", + srcs = ["cpu_parallelization_preparation.cc"], + hdrs = ["cpu_parallelization_preparation.h"], + deps = [ + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/compiler/xla/service:logical_buffer", + "//tensorflow/compiler/xla/service:tuple_points_to_analysis", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "elemental_ir_emitter", + srcs = ["elemental_ir_emitter.cc"], + hdrs = ["elemental_ir_emitter.h"], + deps = [ + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:elemental_ir_emitter", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "@llvm//:core", + ], +) + +cc_library( + name = "ir_emission_utils", + srcs = ["ir_emission_utils.cc"], + hdrs = ["ir_emission_utils.h"], + deps = [ + ":cpu_runtime", + 
"//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:window_util", + "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags", + "//tensorflow/compiler/xla/service:hlo", + ], +) + +cc_library( + name = "layout_assignment", + srcs = ["layout_assignment.cc"], + hdrs = ["layout_assignment.h"], + deps = [ + ":ir_emission_utils", + "//tensorflow/compiler/xla/service:computation_layout", + "//tensorflow/compiler/xla/service:layout_assignment", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "conv_canonicalization", + srcs = ["conv_canonicalization.cc"], + hdrs = ["conv_canonicalization.h"], + deps = [ + ":cpu_runtime", + ":ir_emission_utils", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_pass", + ], +) + +cc_test( + name = "conv_canonicalization_test", + srcs = ["conv_canonicalization_test.cc"], + deps = [ + ":conv_canonicalization", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test_main", + ], +) + +# ----------------------------------------------------------------------------- + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/compiler/xla/service/cpu/build_defs.bzl b/tensorflow/compiler/xla/service/cpu/build_defs.bzl new file mode 100644 index 0000000000..b4b5219751 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/build_defs.bzl @@ -0,0 +1,11 @@ +"""build_defs for service/cpu.""" + +def runtime_copts(): + """Returns copts used for CPU runtime libraries.""" + return (["-DEIGEN_AVOID_STL_ARRAY"] + + select({ + "//tensorflow:android_arm": ["-mfpu=neon"], + "//conditions:default": []}) + + select({ + "//tensorflow:android": ["-O2"], + "//conditions:default": []})) diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc new file mode 100644 index 0000000000..89b3302bca --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc @@ -0,0 +1,220 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/compiler_functor.h" + +#include +#include +#include +#include +#include +#include + +#include "external/llvm/include/llvm/ADT/StringRef.h" +#include "external/llvm/include/llvm/Analysis/TargetLibraryInfo.h" +#include "external/llvm/include/llvm/Analysis/TargetTransformInfo.h" +#include "external/llvm/include/llvm/ExecutionEngine/ObjectMemoryBuffer.h" +#include "external/llvm/include/llvm/IR/LegacyPassManager.h" +#include "external/llvm/include/llvm/IR/Verifier.h" +#include "external/llvm/include/llvm/MC/MCContext.h" +#include "external/llvm/include/llvm/Object/ObjectFile.h" +#include "external/llvm/include/llvm/Support/raw_ostream.h" +#include "external/llvm/include/llvm/Target/TargetMachine.h" +#include "external/llvm/include/llvm/Transforms/IPO.h" +#include "external/llvm/include/llvm/Transforms/IPO/AlwaysInliner.h" +#include "external/llvm/include/llvm/Transforms/IPO/PassManagerBuilder.h" +#include "tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { +namespace cpu { + +/* static */ CompilerFunctor::VectorIntrinsics +CompilerFunctor::AllIntrinsics() { + VectorIntrinsics intrinsics; + intrinsics.sse_intrinsics = true; + intrinsics.avx_intrinsics = true; + return intrinsics; +} + +llvm::object::OwningBinary CompilerFunctor:: +operator()(llvm::Module& module) const { + llvm::legacy::PassManager module_passes; + llvm::legacy::FunctionPassManager function_passes(&module); + + VLOG(2) << "IR before optimizations"; + XLA_VLOG_LINES(2, llvm_ir::DumpModuleToString(module)); + legacy_flags::CompilerFunctorFlags* flags = + legacy_flags::GetCompilerFunctorFlags(); + string dump_path = flags->xla_debug_cpu_dump_ir; + if (!dump_path.empty()) { + std::unique_ptr f; + TF_CHECK_OK(tensorflow::Env::Default()->NewAppendableFile(dump_path, &f)); + TF_CHECK_OK(f->Append(llvm_ir::DumpModuleToString(module))); + TF_CHECK_OK(f->Close()); + } + + // Build up optimization pipeline. + AddOptimizationPasses(&module_passes, &function_passes); + + // Run optimization passes on module. + function_passes.doInitialization(); + for (auto func = module.begin(); func != module.end(); ++func) { + function_passes.run(*func); + } + function_passes.doFinalization(); + module_passes.run(module); + + // Buffer for holding machine code prior to constructing the ObjectFile. + llvm::SmallVector stream_buffer; + llvm::raw_svector_ostream ostream(stream_buffer); + + VLOG(2) << "IR after optimizations"; + XLA_VLOG_LINES(2, llvm_ir::DumpModuleToString(module)); + + // Generate code. + llvm::MCContext* mc_context; + llvm::legacy::PassManager codegen_passes; + target_machine_->addPassesToEmitMC(codegen_passes, mc_context, ostream); + codegen_passes.run(module); + + // Construct ObjectFile from machine code buffer. 
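+ // (The OwningBinary returned below pairs the parsed ObjectFile with the
+ // MemoryBuffer that owns the underlying bytes, since the ObjectFile only
+ // references that memory.)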
+ std::unique_ptr memory_buffer(
+ new llvm::ObjectMemoryBuffer(std::move(stream_buffer)));
+ llvm::Expected>
+ object_file_or_error = llvm::object::ObjectFile::createObjectFile(
+ memory_buffer->getMemBufferRef());
+ CHECK(object_file_or_error);
+
+ std::unique_ptr object_file =
+ std::move(object_file_or_error.get());
+ if (VLOG_IS_ON(2)) {
+ StatusOr disassembly_status =
+ disassembler_->DisassembleObjectFile(*object_file);
+ if (disassembly_status.ok()) {
+ XLA_VLOG_LINES(2, disassembly_status.ValueOrDie());
+ }
+ }
+
+ return llvm::object::OwningBinary(
+ std::move(object_file), std::move(memory_buffer));
+}
+
+namespace {
+// Returns the set of vectorized library functions supported for the target.
+std::vector VectorFunctionsForTargetLibraryInfoImpl(
+ llvm::Triple::ArchType arch, llvm::StringRef feature_string,
+ CompilerFunctor::VectorIntrinsics const& available_intrinsics) {
+ std::vector vector_functions;
+
+ const llvm::VecDesc four_wide_vector_functions[] = {
+ {"expf", runtime::kExpV4F32, 4},
+ {"llvm.exp.f32", runtime::kExpV4F32, 4},
+
+ {"logf", runtime::kLogV4F32, 4},
+ {"llvm.log.f32", runtime::kLogV4F32, 4},
+
+ {"tanhf", runtime::kTanhV4F32, 4},
+ {"llvm.tanh.f32", runtime::kTanhV4F32, 4},
+ };
+
+ const llvm::VecDesc eight_wide_vector_functions[] = {
+ {"expf", runtime::kExpV8F32, 8},
+ {"llvm.exp.f32", runtime::kExpV8F32, 8},
+
+ {"logf", runtime::kLogV8F32, 8},
+ {"llvm.log.f32", runtime::kLogV8F32, 8},
+
+ {"tanhf", runtime::kTanhV8F32, 8},
+ {"llvm.tanh.f32", runtime::kTanhV8F32, 8},
+ };
+
+ // Our vectorized library calls are currently implemented by calling into Eigen.
+ // As such, only emit calls to these routines if --xla_cpu_use_eigen is
+ // enabled.
+ legacy_flags::CpuRuntimeFlags* flags = legacy_flags::GetCpuRuntimeFlags();
+ if (flags->xla_cpu_use_eigen &&
+ (arch == llvm::Triple::x86 || arch == llvm::Triple::x86_64)) {
+ llvm::SmallVector features;
+ feature_string.split(features, ',', -1, /*KeepEmpty=*/false);
+ if (std::find(features.begin(), features.end(), "+sse4.1") !=
+ features.end() &&
+ available_intrinsics.sse_intrinsics) {
+ vector_functions.insert(vector_functions.end(),
+ std::begin(four_wide_vector_functions),
+ std::end(four_wide_vector_functions));
+ }
+ if (std::find(features.begin(), features.end(), "+avx") != features.end() &&
+ available_intrinsics.avx_intrinsics) {
+ vector_functions.insert(vector_functions.end(),
+ std::begin(eight_wide_vector_functions),
+ std::end(eight_wide_vector_functions));
+ }
+ }
+ return vector_functions;
+}
+} // namespace
+
+void CompilerFunctor::AddOptimizationPasses(
+ llvm::legacy::PassManagerBase* module_passes,
+ llvm::legacy::FunctionPassManager* function_passes) const {
+ llvm::Triple target_triple(target_machine_->getTargetTriple());
+ auto target_library_info_impl =
+ MakeUnique(target_triple);
+ target_library_info_impl->addVectorizableFunctions(
+ VectorFunctionsForTargetLibraryInfoImpl(
+ target_triple.getArch(), target_machine_->getTargetFeatureString(),
+ available_intrinsics_));
+ module_passes->add(
+ new llvm::TargetLibraryInfoWrapperPass(*target_library_info_impl));
+ module_passes->add(createTargetTransformInfoWrapperPass(
+ target_machine_->getTargetIRAnalysis()));
+
+ module_passes->add(llvm::createVerifierPass());
+
+ llvm::PassManagerBuilder builder;
+ builder.OptLevel = opt_level_;
+ builder.SizeLevel = 0;
+
+ if (opt_level_ > 1) {
+ builder.Inliner = llvm::createFunctionInliningPass();
+ } else {
+ // Only inline functions marked with "alwaysinline".
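+ // (Taken together, the settings in this function give roughly: opt level 0 -
+ // always-inline only, no unrolling or vectorization; level 1 - adds loop
+ // unrolling and loop vectorization; levels 2 and above - adds full function
+ // inlining and SLP vectorization.)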
+ builder.Inliner = llvm::createAlwaysInlinerLegacyPass(); + } + + builder.DisableUnitAtATime = false; + builder.DisableUnrollLoops = opt_level_ == 0; + builder.LoopVectorize = opt_level_ > 0; + builder.SLPVectorize = opt_level_ > 1; + + builder.populateFunctionPassManager(*function_passes); + builder.populateModulePassManager(*module_passes); + + module_passes->add(llvm::createVerifierPass()); +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.h b/tensorflow/compiler/xla/service/cpu/compiler_functor.h new file mode 100644 index 0000000000..17dadebe97 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.h @@ -0,0 +1,69 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_COMPILER_FUNCTOR_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_COMPILER_FUNCTOR_H_ + +#include "external/llvm/include/llvm/IR/LegacyPassManager.h" +#include "external/llvm/include/llvm/IR/Module.h" +#include "external/llvm/include/llvm/Object/ObjectFile.h" +#include "external/llvm/include/llvm/Target/TargetMachine.h" +#include "tensorflow/compiler/xla/service/cpu/disassembler.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { +namespace cpu { + +// Functor class for compiling an LLVM module down to an object file. For use by +// Orc JIT compile layer. +class CompilerFunctor { + public: + // Describes the set of vector intrinsics available to the generated code. + struct VectorIntrinsics { + bool sse_intrinsics; + bool avx_intrinsics; + }; + + // Returns a VectorIntrinsics where all intrinsics are available. + static VectorIntrinsics AllIntrinsics(); + + explicit CompilerFunctor(llvm::TargetMachine* target_machine, + const Disassembler* disassembler, int opt_level, + const VectorIntrinsics& available_intrinsics) + : target_machine_(target_machine), + disassembler_(CHECK_NOTNULL(disassembler)), + opt_level_(opt_level), + available_intrinsics_(available_intrinsics) {} + + // Compile a Module to an ObjectFile. + llvm::object::OwningBinary operator()( + llvm::Module& module) const; // NOLINT + + private: + // Populates the given pass managers based on the optimization level. 
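+ // (operator() runs the function passes over each function first, then the
+ // module passes over the whole module.)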
+ void AddOptimizationPasses( + llvm::legacy::PassManagerBase* module_passes, + llvm::legacy::FunctionPassManager* function_passes) const; + + llvm::TargetMachine* target_machine_; + const Disassembler* disassembler_; + const unsigned opt_level_; + const VectorIntrinsics available_intrinsics_; +}; + +} // namespace cpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_COMPILER_FUNCTOR_H_ diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc new file mode 100644 index 0000000000..2c3fc0abbc --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc @@ -0,0 +1,148 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h" + +#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" +#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace cpu { + +StatusOr ConvCanonicalization::Run(HloModule* module) { + legacy_flags::CpuRuntimeFlags* flags = legacy_flags::GetCpuRuntimeFlags(); + if (!flags->xla_cpu_use_eigen) { + return false; + } + + bool changed = false; + for (HloInstruction* hlo : + module->entry_computation()->MakeInstructionPostOrder()) { + if (hlo->opcode() == HloOpcode::kConvolution && + !PotentiallyImplementedAsEigenConvolution(*hlo)) { + const ConvolutionDimensionNumbers& dnums = + hlo->convolution_dimension_numbers(); + auto batch_dim = dnums.batch_dimension(); + auto feature_dim = dnums.feature_dimension(); + auto kernel_input_feature_dim = dnums.kernel_input_feature_dimension(); + auto kernel_output_feature_dim = dnums.kernel_output_feature_dimension(); + + int num_spatial_dims = dnums.spatial_dimensions_size(); + int num_dims = num_spatial_dims + 2; + + // A canonical convolution's dimension numbers need to satisfy the + // following conditions (see cs/PotentiallyImplementedAsEigenConvolution). + // + // - the input is in NHWC or NWHC order. + // - the kernel is in HWIO or WHIO order. + // - the spatial dimensions are in the same relative order in the input, + // kernel and output. + // + // For simplicity, as a first step, we reshape the input and filter to + // NHWC and HWIO order, respectively. This may lose precision but not + // break the soundness. 
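+ // Worked example (mirroring the non-canonical case in
+ // conv_canonicalization_test.cc): for an input laid out in CNHW order the
+ // dimension numbers are feature=0, batch=1, spatial={2,3}, so the code below
+ // builds new_input_dim_order = {batch, spatial..., feature} = {1, 2, 3, 0}
+ // and emits a kTranspose with that permutation to obtain an NHWC operand.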
+ HloInstruction* input = hlo->mutable_operand(0); + + std::vector new_input_dim_order(num_dims); + std::vector new_input_dims(num_dims); + new_input_dim_order[0] = batch_dim; + new_input_dims[0] = input->shape().dimensions(batch_dim); + for (int i = 0; i < num_spatial_dims; ++i) { + new_input_dim_order[i + 1] = dnums.spatial_dimensions(i); + new_input_dims[i + 1] = + input->shape().dimensions(dnums.spatial_dimensions(i)); + } + new_input_dim_order[num_dims - 1] = feature_dim; + new_input_dims[num_dims - 1] = input->shape().dimensions(feature_dim); + + Shape new_input_shape = + ShapeUtil::MakeShape(input->shape().element_type(), new_input_dims); + HloInstruction* new_input = module->entry_computation()->AddInstruction( + HloInstruction::CreateTranspose(new_input_shape, input, + new_input_dim_order)); + + HloInstruction* kernel = hlo->mutable_operand(1); + + std::vector new_kernel_dim_order(num_dims); + std::vector new_kernel_dims(num_dims); + for (int i = 0; i < num_spatial_dims; ++i) { + new_kernel_dim_order[i] = dnums.kernel_spatial_dimensions(i); + new_kernel_dims[i] = + kernel->shape().dimensions(dnums.kernel_spatial_dimensions(i)); + } + new_kernel_dim_order[num_dims - 2] = kernel_input_feature_dim; + new_kernel_dims[num_dims - 2] = + kernel->shape().dimensions(kernel_input_feature_dim); + new_kernel_dim_order[num_dims - 1] = kernel_output_feature_dim; + new_kernel_dims[num_dims - 1] = + kernel->shape().dimensions(kernel_output_feature_dim); + + Shape new_kernel_shape = + ShapeUtil::MakeShape(kernel->shape().element_type(), new_kernel_dims); + HloInstruction* new_kernel = module->entry_computation()->AddInstruction( + HloInstruction::CreateTranspose(new_kernel_shape, kernel, + new_kernel_dim_order)); + + std::vector new_conv_dims(num_dims); + new_conv_dims[0] = hlo->shape().dimensions(batch_dim); + for (int i = 0; i < num_spatial_dims; ++i) { + new_conv_dims[i + 1] = + hlo->shape().dimensions(dnums.spatial_dimensions(i)); + } + new_conv_dims[num_dims - 1] = hlo->shape().dimensions(feature_dim); + Shape new_conv_shape = + ShapeUtil::MakeShape(hlo->shape().element_type(), new_conv_dims); + + ConvolutionDimensionNumbers new_dnums; + new_dnums.set_batch_dimension(0); + for (int i = 0; i < num_spatial_dims; ++i) { + new_dnums.add_spatial_dimensions(i + 1); + new_dnums.add_kernel_spatial_dimensions(i); + } + new_dnums.set_feature_dimension(num_dims - 1); + new_dnums.set_kernel_input_feature_dimension(num_dims - 2); + new_dnums.set_kernel_output_feature_dimension(num_dims - 1); + + // The window of the old convolution is reused, because reshapes only + // change the dimension mapping but not the dimension sizes. For + // example, input height and width are the same as before the reshapes. + HloInstruction* new_conv = module->entry_computation()->AddInstruction( + HloInstruction::CreateConvolve(new_conv_shape, new_input, new_kernel, + hlo->window(), new_dnums)); + + // kConvolution inherits the dimension mapping of its input, so we need to + // reshape the output back to the shape of the original convolution. This + // is done by apply the inverse permutation of the collapsing order of the + // input reshape. 
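+ // Continuing the CNHW example: new_input_dim_order is {1, 2, 3, 0} and
+ // InversePermutation({1, 2, 3, 0}) = {3, 0, 1, 2}, which restores the
+ // original CNHW output order after the canonical NHWC convolution.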
+ module->entry_computation()->ReplaceWithNewInstruction( + hlo, + HloInstruction::CreateTranspose( + hlo->shape(), new_conv, InversePermutation(new_input_dim_order))); + changed = true; + } + } + + return changed; +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h new file mode 100644 index 0000000000..57e17eb010 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h @@ -0,0 +1,44 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CONV_CANONICALIZATION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CONV_CANONICALIZATION_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass.h" + +namespace xla { +namespace cpu { + +// An HLO pass that canonicalizes the dimension numbers of all top-level +// convolutions in the given module. +// +// In order to hit the fast path of using Eigen's convolution implementation, a +// convolution's dimension numbers need to satisfy certain constraints (so +// called canonical convolutions). This pass expands non-canonical convolutions +// into reshapes and canonical convolutions, so that these non-canonical +// convolutions can run faster. +class ConvCanonicalization : public HloPass { + public: + ConvCanonicalization() : HloPass("convolution-canonicalization") {} + ~ConvCanonicalization() override {} + + StatusOr Run(HloModule* module) override; +}; + +} // namespace cpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CONV_CANONICALIZATION_H_ diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc new file mode 100644 index 0000000000..d18141af83 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc @@ -0,0 +1,146 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h" + +#include + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/util.h" + +#include "tensorflow/compiler/xla/test_helpers.h" + +namespace xla { +namespace cpu { + +class ConvCanonicalizationTest : public HloTestBase { + public: + ConvCanonicalizationTest() { + for (int i = 0; i < 2; ++i) { + auto dim = conv_window_.add_dimensions(); + dim->set_size(kWindowSize); + dim->set_stride(1); + dim->set_padding_low(0); + dim->set_padding_high(0); + dim->set_window_dilation(1); + dim->set_base_dilation(1); + } + } + + protected: + Window conv_window_; + + static constexpr int kBatchSize = 50; + static constexpr int kInputSize = 28; + static constexpr int kWindowSize = 5; + static constexpr int kInputFeatureCount = 32; + static constexpr int kOutputFeatureCount = 64; +}; + +TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) { + auto builder = HloComputation::Builder(TestName()); + // The input dimensions are in CNHW order. + auto input = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR4FromArray4D(Array4D( + kInputFeatureCount, kBatchSize, kInputSize, kInputSize)))); + // The kernel dimensions are in OIHW order. + auto kernel = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR4FromArray4D(Array4D( + kOutputFeatureCount, kInputFeatureCount, kWindowSize, kWindowSize)))); + + ConvolutionDimensionNumbers dnums; + dnums.set_batch_dimension(1); + dnums.add_spatial_dimensions(2); + dnums.add_spatial_dimensions(3); + dnums.set_feature_dimension(0); + dnums.add_kernel_spatial_dimensions(2); + dnums.add_kernel_spatial_dimensions(3); + dnums.set_kernel_input_feature_dimension(1); + dnums.set_kernel_output_feature_dimension(0); + auto output_size = kInputSize - kWindowSize + 1; + builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeUtil::MakeShape( + F32, {kOutputFeatureCount, kBatchSize, output_size, output_size}), + input, kernel, conv_window_, dnums)); + + auto module = MakeUnique(TestName()); + HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + + ConvCanonicalization conv_canonicalization; + EXPECT_TRUE(conv_canonicalization.Run(module.get()).ValueOrDie()); + + const HloInstruction* output_reshape = entry_computation->root_instruction(); + EXPECT_EQ(HloOpcode::kTranspose, output_reshape->opcode()); + const HloInstruction* canonical_conv = output_reshape->operand(0); + EXPECT_EQ(HloOpcode::kConvolution, canonical_conv->opcode()); + const HloInstruction* input_reshape = canonical_conv->operand(0); + EXPECT_EQ(HloOpcode::kTranspose, input_reshape->opcode()); + const HloInstruction* kernel_reshape = canonical_conv->operand(1); + EXPECT_EQ(HloOpcode::kTranspose, kernel_reshape->opcode()); + + // The input is in CNHW order. input_reshape should produce + // NHWC for the convolution to hit the Eigen fast path. + EXPECT_TRUE(ContainersEqual(input_reshape->dimensions(), + std::vector({1, 2, 3, 0}))); + // The kernel is in OIHW order. kernel_reshape should produce + // HWIO for the convolution to hit the Eigen fast path. 
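+ // (In OIHW the dimensions are O=0, I=1, H=2, W=3, so the expected
+ // permutation {H, W, I, O} is {2, 3, 1, 0}.)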
+ EXPECT_TRUE(ContainersEqual(kernel_reshape->dimensions(), + std::vector({2, 3, 1, 0}))); + // The output of the canonical convolution is in NHWC order (the same as + // input_reshape's order). output_reshape should restore that order to the + // order of the computation root (CNHW). + EXPECT_TRUE(ContainersEqual(output_reshape->dimensions(), + std::vector({3, 0, 1, 2}))); +} + +TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) { + auto builder = HloComputation::Builder(TestName()); + // The input dimensions are in NHWC order. + auto input = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR4FromArray4D(Array4D( + kBatchSize, kInputSize, kInputSize, kInputFeatureCount)))); + // The kernel dimensions are in HWIO order. + auto kernel = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR4FromArray4D(Array4D( + kWindowSize, kWindowSize, kInputFeatureCount, kOutputFeatureCount)))); + + ConvolutionDimensionNumbers dnums; + dnums.set_batch_dimension(0); + dnums.add_spatial_dimensions(1); + dnums.add_spatial_dimensions(2); + dnums.set_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + auto output_size = kInputSize - kWindowSize + 1; + builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeUtil::MakeShape( + F32, {kBatchSize, output_size, output_size, kOutputFeatureCount}), + input, kernel, conv_window_, dnums)); + + auto module = MakeUnique(TestName()); + module->AddEntryComputation(builder.Build()); + + ConvCanonicalization conv_canonicalization; + EXPECT_FALSE(conv_canonicalization.Run(module.get()).ValueOrDie()); +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc new file mode 100644 index 0000000000..d566cfd8c8 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -0,0 +1,631 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" + +#include +#include +#include +#include +#include +#include +#include + +// IWYU pragma: no_include "llvm/Config/Disassemblers.def.inc" +// IWYU pragma: no_include "llvm/Config/Targets.def.inc" +#include "external/llvm/include/llvm/ADT/StringRef.h" +#include "external/llvm/include/llvm/ADT/Triple.h" +#include "external/llvm/include/llvm/IR/Function.h" +#include "external/llvm/include/llvm/IR/LLVMContext.h" +#include "external/llvm/include/llvm/IR/Module.h" +#include "external/llvm/include/llvm/Object/ObjectFile.h" +#include "external/llvm/include/llvm/Support/CommandLine.h" +#include "external/llvm/include/llvm/Support/TargetRegistry.h" +#include "external/llvm/include/llvm/Support/TargetSelect.h" +#include "external/llvm/include/llvm/Target/TargetMachine.h" +#include "external/llvm/include/llvm/Target/TargetOptions.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/port/initialize.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/algebraic_simplifier.h" +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/buffer_liveness.h" +#include "tensorflow/compiler/xla/service/copy_insertion.h" +#include "tensorflow/compiler/xla/service/cpu/compiler_functor.h" +#include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_executable.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h" +#include "tensorflow/compiler/xla/service/cpu/disassembler.h" +#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/cpu/ir_emitter.h" +#include "tensorflow/compiler/xla/service/cpu/layout_assignment.h" +#include "tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h" +#include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_cse.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_pass.h" +#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" +#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h" +#include "tensorflow/compiler/xla/service/inliner.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/service/reshape_mover.h" +#include "tensorflow/compiler/xla/service/transpose_folding.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/strings/str_util.h" + +namespace se = ::perftools::gputools; + +namespace xla { +namespace cpu { + +CpuAotCompilationOptions::CpuAotCompilationOptions( + string triple, string cpu_name, string features, string 
entry_point_name, + RelocationModel relocation_model) + : triple_(std::move(triple)), + cpu_name_(std::move(cpu_name)), + features_(std::move(features)), + entry_point_name_(std::move(entry_point_name)), + relocation_model_(relocation_model) {} + +CpuAotCompilationOptions::~CpuAotCompilationOptions() = default; + +se::Platform::Id CpuAotCompilationOptions::PlatformId() const { + return se::host::kHostPlatformId; +} + +CpuAotCompilationResult::CpuAotCompilationResult( + ObjectFileData object_file_data, BufferSizes buffer_sizes, + int64 result_buffer_index) + : object_file_data_(std::move(object_file_data)), + buffer_sizes_(std::move(buffer_sizes)), + result_buffer_index_(result_buffer_index) {} + +CpuAotCompilationResult::~CpuAotCompilationResult() = default; + +CpuCompiler::CpuCompiler() { + // Initialize LLVM the first time the CpuCompiler is initialized. + static bool llvm_initialized = []() { + InitializeLLVMTarget(); + return true; + }(); + (void)llvm_initialized; +} + +/* static */ void CpuCompiler::InitializeLLVMTarget() { + // Initialize LLVM's MC layer for the native target. + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + LLVMInitializeX86Target(); + LLVMInitializeX86TargetInfo(); + LLVMInitializeX86TargetMC(); + LLVMInitializeX86AsmPrinter(); + LLVMInitializeX86Disassembler(); + LLVMInitializeARMTarget(); + LLVMInitializeARMTargetInfo(); + LLVMInitializeARMTargetMC(); + LLVMInitializeARMAsmPrinter(); + LLVMInitializeARMDisassembler(); + LLVMInitializeAArch64Target(); + LLVMInitializeAArch64TargetInfo(); + LLVMInitializeAArch64TargetMC(); + LLVMInitializeAArch64AsmPrinter(); + LLVMInitializeAArch64Disassembler(); + LLVMInitializePowerPCTarget(); + LLVMInitializePowerPCTargetInfo(); + LLVMInitializePowerPCTargetMC(); + LLVMInitializePowerPCAsmPrinter(); + LLVMInitializePowerPCDisassembler(); + + // LLVM command-line flags are global, so set them during initialization. + legacy_flags::CpuCompilerFlags* flags = legacy_flags::GetCpuCompilerFlags(); + if (!flags->xla_cpu_llvm_cl_opts.empty()) { + std::vector opts = + tensorflow::str_util::Split(flags->xla_cpu_llvm_cl_opts, ','); + std::vector fake_argv; + fake_argv.push_back(""); + for (const string& opt : opts) { + fake_argv.push_back(opt.c_str()); + } + llvm::cl::ParseCommandLineOptions(fake_argv.size(), &fake_argv[0]); + } +} + +namespace { +// This visitor records which HLO instructions should have profiling information +// recorded. +class CollectProfileCandidates : public DfsHloVisitorWithDefault { + public: + static StatusOr> + GetCandidatesForComputation(HloComputation* computation) { + std::unordered_map hlo_to_profile_idx; + CollectProfileCandidates profile_candidates_for_computation( + &hlo_to_profile_idx); + TF_RETURN_IF_ERROR(computation->root_instruction()->Accept( + &profile_candidates_for_computation)); + return hlo_to_profile_idx; + } + + private: + explicit CollectProfileCandidates( + std::unordered_map* hlo_to_profile_idx) + : hlo_to_profile_idx_(hlo_to_profile_idx) {} + + Status DefaultAction(HloInstruction* hlo_instruction) override { + hlo_to_profile_idx_->insert({hlo_instruction, hlo_to_profile_idx_->size()}); + return Status::OK(); + } + // Skip constants, there is nothing to profile. + Status HandleConstant(HloInstruction* /*constant*/, + const Literal& /*literal*/) override { + return Status::OK(); + } + // Skip parameters, they are a simple load. 
+ Status HandleParameter(HloInstruction* /*parameter*/) override { + return Status::OK(); + } + // It is important to recurse for "while" or else we risk overly coarse + // profiling information. + Status HandleWhile(HloInstruction* xla_while, HloInstruction* /*init*/, + HloComputation* condition, HloComputation* body) override { + TF_RETURN_IF_ERROR(DefaultAction(xla_while)); + + CollectProfileCandidates candidates_for_condition(hlo_to_profile_idx_); + TF_RETURN_IF_ERROR( + condition->root_instruction()->Accept(&candidates_for_condition)); + + CollectProfileCandidates candidates_for_body(hlo_to_profile_idx_); + TF_RETURN_IF_ERROR(body->root_instruction()->Accept(&candidates_for_body)); + + return Status::OK(); + } + + std::unordered_map* hlo_to_profile_idx_; +}; +} // namespace + +Status CpuCompiler::RunHloPasses(HloModule* hlo_module, + HloModuleConfig* module_config, + HloDumper dump_hlo) { + // Optimization pipeline. + HloPassPipeline pipeline("CPU", dump_hlo); + pipeline.AddPass(); + pipeline.AddPass(); + { + auto& pass = pipeline.AddPass>("simplification", + dump_hlo); + pass.AddPass( + /*is_layout_sensitive=*/false, + [](const Shape&, const Shape&) { return false; }); + pass.AddPass(); + } + pipeline.AddPass(PotentiallyImplementedAsEigenDot); + pipeline.AddPass(); + pipeline.AddPass(/*is_layout_sensitive=*/false); + pipeline.AddPass(); + pipeline.AddPass( + module_config->mutable_entry_computation_layout()); + // The LayoutAssignment pass may leave behind kCopy instructions which are + // duplicate or NOPs, so remove them with algebraic simplification and CSE. + pipeline.AddPass>( + /*is_layout_sensitive=*/true, + [](const Shape&, const Shape&) { return true; }); + pipeline.AddPass(/*is_layout_sensitive=*/true); + // Copy insertion should be performed immediately before IR emission to + // avoid inserting unnecessary copies (later pass adds an instruction which + // materializes the value) or missing a necessary copy (later pass removes + // an instruction which materializes a value). + pipeline.AddPass(); + pipeline.AddPass(); + legacy_flags::CpuCompilerFlags* flags = legacy_flags::GetCpuCompilerFlags(); + if (flags->xla_cpu_parallel) { + pipeline.AddPass(); + } + return pipeline.Run(hlo_module).status(); +} + +namespace { + +llvm::TargetOptions CompilerTargetOptions() { + llvm::TargetOptions target_options; + llvm_ir::SetTargetOptions(&target_options); + return target_options; +} + +llvm::CodeGenOpt::Level CodeGenOptLevel() { + legacy_flags::CpuCompilerFlags* flags = legacy_flags::GetCpuCompilerFlags(); + switch (flags->xla_cpu_llvm_opt_level) { + case 1: + return llvm::CodeGenOpt::Less; + case 2: + return llvm::CodeGenOpt::Default; + break; + case 3: + return llvm::CodeGenOpt::Aggressive; + break; + default: + return llvm::CodeGenOpt::None; + } +} + +// Constructs and returns a sequence for the HLO instructions in each +// computation in the given module. The sequence can be used to determine the +// order of HLO instruction emission and for buffer liveness analysis. +SequentialHloOrdering::HloModuleSequence CreateModuleSequence( + const HloModule* module) { + SequentialHloOrdering::HloModuleSequence sequence; + for (auto& computation : module->computations()) { + // Do a DFS traversal from the root to construct a sequence for each + // computation. + // TODO(b/32006145): Construct a sequence to minimize memory pressure. 
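+    // Accept() below invokes the visitor in DFS post order, so operands are
+    // appended before their users and the root instruction comes last.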
+ std::vector order; + TF_CHECK_OK(computation->root_instruction()->Accept( + [&order](HloInstruction* instruction) { + order.push_back(instruction); + return Status::OK(); + })); + sequence.emplace(computation.get(), std::move(order)); + } + return sequence; +} + +} // namespace + +StatusOr> CpuCompiler::Compile( + std::unique_ptr hlo_module, + std::unique_ptr module_config, HloDumper dump_hlo, + se::StreamExecutor* stream_exec) { + TF_RET_CHECK(stream_exec != nullptr); + + // Compile must be thread-safe so create a new LLVM context for the module. + auto llvm_context = MakeUnique(); + auto llvm_module = + MakeUnique("__compute_module", *llvm_context); + auto jit = + MakeUnique(CompilerTargetOptions(), CodeGenOptLevel()); + llvm_module->setDataLayout(jit->data_layout()); + llvm_module->setTargetTriple(jit->target_triple().getTriple()); + const llvm::DataLayout& data_layout = llvm_module->getDataLayout(); + int64 pointer_size = data_layout.getPointerSize(); + + TF_RETURN_IF_ERROR( + RunHloPasses(hlo_module.get(), module_config.get(), dump_hlo)); + + HloComputation* computation = hlo_module->entry_computation(); + std::unordered_map hlo_to_profile_idx; + if (module_config->hlo_profiling_enabled()) { + TF_ASSIGN_OR_RETURN( + hlo_to_profile_idx, + CollectProfileCandidates::GetCandidatesForComputation(computation)); + } + + std::unique_ptr cpu_executable; + legacy_flags::CpuCompilerFlags* flags = legacy_flags::GetCpuCompilerFlags(); + if (flags->xla_cpu_parallel) { + // Run buffer analysis on the HLO graph. This analysis figures out which + // temporary buffers are required to run the computation. + // DependencyHloOrdering is used for the parallel emitter because the order + // of HLO instruction execution is not known ahead of time. + // DependencyHloOrdering is the most conservative partial order and only + // uses data dependencies for determining order. + TF_ASSIGN_OR_RETURN( + std::unique_ptr assignment, + BufferAssigner::Run(hlo_module.get(), + MakeUnique(hlo_module.get()), + pointer_size)); + + // If we are using the parallel CPU backend, we need to create map from + // HloInstruction to the corresponding generated function name. + std::map parallel_computations; + std::unordered_map> + aligned_constants; + for (auto instruction : computation->MakeInstructionPostOrder()) { + // Parameters and constants don't get their own computation. + if (instruction->opcode() == HloOpcode::kParameter) { + continue; + } + if (instruction->opcode() == HloOpcode::kConstant) { + // Copy the constant out of the ProtocolBuffer so that we can give it a + // higher alignment. + const void* data = LiteralUtil::InternalData(instruction->literal()); + int64 size = llvm_ir::ByteSizeOf(instruction->shape(), data_layout); + auto iter = aligned_constants.emplace( + instruction, MakeUnique(size)); + CHECK_EQ(iter.second, true); + unsigned char* aligned_data = iter.first->second.get(); + memcpy(aligned_data, data, size); + continue; + } + // The parallel preparation should have ensured that the top-level + // computation consists solely of Call instructions. 
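+      // (ParallelizationPreparation, added elsewhere in this change, outlines
+      // every other instruction of the entry computation into an embedded
+      // computation invoked via kCall.)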
+ TF_RET_CHECK(instruction->opcode() == HloOpcode::kCall); + HloComputation* to_apply = instruction->to_apply(); + parallel_computations.emplace(to_apply, instruction); + } + + IrEmitter ir_emitter(*hlo_module, *module_config, *assignment, + llvm_module.get(), &hlo_to_profile_idx); + std::unique_ptr> function_names( + new std::map()); + for (auto embedded_computation : + computation->MakeEmbeddedComputationsList()) { + auto parallel_computation_iter = + parallel_computations.find(embedded_computation); + // All parallel computations are considered to be an entry computation for + // IR generation purposes. + bool computation_is_parallel = + parallel_computation_iter != parallel_computations.end(); + TF_ASSIGN_OR_RETURN( + llvm::Function * ir_function, + ir_emitter.EmitComputation( + embedded_computation, embedded_computation->name(), + /*is_entry_computation=*/computation_is_parallel)); + // If this computation is parallel, remember it in the function name map. + // This way we know what function to execute when we try to run code for + // the Call instruction. + if (computation_is_parallel) { + HloInstruction* call_instruction = parallel_computation_iter->second; + InsertOrDie(function_names.get(), call_instruction, + llvm_ir::AsString(ir_function->getName())); + } + } + + string ir_module_string; + if (flags->xla_cpu_embed_ir) { + ir_module_string = llvm_ir::DumpModuleToString(*llvm_module); + } + + // JIT compile the LLVM IR module to in-memory machine code. + jit->AddModule(std::move(llvm_module)); + cpu_executable.reset(new ParallelCpuExecutable( + std::move(jit), std::move(assignment), std::move(hlo_module), + std::move(module_config), std::move(function_names), + std::move(hlo_to_profile_idx), std::move(aligned_constants))); + + if (flags->xla_cpu_embed_ir) { + static_cast(*cpu_executable) + .set_ir_module_string(ir_module_string); + } + } else { + // Select an order for emitting the HLO instructions for each + // computation. Using this sequence enables tighter buffer liveness analysis + // and reduced memory usage (as compared to using DependencyHloOrdering). + SequentialHloOrdering::HloModuleSequence module_sequence = + CreateModuleSequence(hlo_module.get()); + + // Run buffer analysis on the HLO graph. This analysis figures out which + // temporary buffers are required to run the computation. + TF_ASSIGN_OR_RETURN( + std::unique_ptr assignment, + BufferAssigner::Run(hlo_module.get(), + MakeUnique(hlo_module.get(), + module_sequence), + pointer_size)); + + // Each computation is a single function. Emit all embedded computations + // before the entry computation. The order of computations returned from + // GetEmbeddedComputations guarantees that a called computation occurs + // before a caller computation. + IrEmitter ir_emitter(*hlo_module, *module_config, *assignment, + llvm_module.get(), &hlo_to_profile_idx); + for (auto embedded_computation : + computation->MakeEmbeddedComputationsList()) { + TF_RETURN_IF_ERROR( + ir_emitter + .EmitComputation(embedded_computation, + embedded_computation->name(), + /*is_entry_computation=*/false, + &module_sequence.at(embedded_computation)) + .status()); + } + string function_name_prefix = + computation->name().empty() ? 
"__compute" : computation->name(); + TF_ASSIGN_OR_RETURN( + llvm::Function * entry_function, + ir_emitter.EmitComputation(computation, function_name_prefix, + /*is_entry_computation=*/true, + &module_sequence.at(computation))); + + string function_name = llvm_ir::AsString(entry_function->getName()); + string ir_module_string; + if (flags->xla_cpu_embed_ir) { + ir_module_string = llvm_ir::DumpModuleToString(*llvm_module); + } + + // JIT compile the LLVM IR module to in-memory machine code. + jit->AddModule(std::move(llvm_module)); + cpu_executable.reset( + new CpuExecutable(std::move(jit), std::move(assignment), + std::move(hlo_module), std::move(module_config), + function_name, std::move(hlo_to_profile_idx))); + + if (flags->xla_cpu_embed_ir) { + static_cast(*cpu_executable) + .set_ir_module_string(ir_module_string); + } + } + + return std::move(cpu_executable); +} + +StatusOr>> CpuCompiler::Compile( + std::vector> hlo_modules, + std::vector> module_configs, + HloDumper dump_hlos, std::vector stream_execs) { + return Unimplemented( + "Compilation of multiple HLO modules is not yet supported on CPU."); +} + +StatusOr> CpuCompiler::CompileAheadOfTime( + std::unique_ptr hlo_module, + std::unique_ptr module_config, HloDumper dump_hlo, + const AotCompilationOptions& aot_options) { + if (aot_options.PlatformId() != se::host::kHostPlatformId) { + return InvalidArgument("Incompatible AOT compilation platform"); + } + const CpuAotCompilationOptions& options = + static_cast(aot_options); + llvm::StringRef target_triple = llvm_ir::AsStringRef(options.triple()); + llvm::Triple triple(llvm::Triple::normalize(target_triple)); + std::string error; + const llvm::Target* target = + llvm::TargetRegistry::lookupTarget(triple.getTriple(), error); + if (target == nullptr) { + return InternalError("TargetRegistry::lookupTarget failed: %s", + error.c_str()); + } + + llvm::Reloc::Model reloc_model; + llvm::PICLevel::Level pic_level; + llvm::PIELevel::Level pie_level; + switch (options.relocation_model()) { + case CpuAotCompilationOptions::RelocationModel::Static: + reloc_model = llvm::Reloc::Static; + pic_level = llvm::PICLevel::NotPIC; + pie_level = llvm::PIELevel::Default; + break; + case CpuAotCompilationOptions::RelocationModel::SmallPic: + reloc_model = llvm::Reloc::PIC_; + pic_level = llvm::PICLevel::SmallPIC; + pie_level = llvm::PIELevel::Default; + break; + case CpuAotCompilationOptions::RelocationModel::BigPic: + reloc_model = llvm::Reloc::PIC_; + pic_level = llvm::PICLevel::BigPIC; + pie_level = llvm::PIELevel::Default; + break; + case CpuAotCompilationOptions::RelocationModel::SmallPie: + reloc_model = llvm::Reloc::PIC_; + pic_level = llvm::PICLevel::SmallPIC; + pie_level = llvm::PIELevel::Small; + break; + case CpuAotCompilationOptions::RelocationModel::BigPie: + reloc_model = llvm::Reloc::PIC_; + pic_level = llvm::PICLevel::BigPIC; + pie_level = llvm::PIELevel::Large; + break; + } + llvm::StringRef cpu_name = llvm_ir::AsStringRef(options.cpu_name()); + llvm::StringRef features = llvm_ir::AsStringRef(options.features()); + llvm::CodeGenOpt::Level opt_level = CodeGenOptLevel(); + std::unique_ptr target_machine = + WrapUnique(target->createTargetMachine( + triple.getTriple(), cpu_name, features, CompilerTargetOptions(), + reloc_model, llvm::CodeModel::Default, opt_level)); + + // Compile must be thread-safe so create a new LLVM context for the module. 
+ llvm::LLVMContext llvm_context; + llvm::Module llvm_module("__compute_module", llvm_context); + llvm_module.setDataLayout(target_machine->createDataLayout()); + llvm_module.setTargetTriple(triple.getTriple()); + if (pic_level != llvm::PICLevel::NotPIC) { + llvm_module.setPICLevel(pic_level); + } + if (pie_level != llvm::PIELevel::Default) { + llvm_module.setPIELevel(pie_level); + } + const llvm::DataLayout& data_layout = llvm_module.getDataLayout(); + int64 pointer_size = data_layout.getPointerSize(); + + TF_RETURN_IF_ERROR( + RunHloPasses(hlo_module.get(), module_config.get(), dump_hlo)); + + SequentialHloOrdering::HloModuleSequence module_sequence = + CreateModuleSequence(hlo_module.get()); + // Run buffer analysis on the HLO graph. This analysis figures out which + // temporary buffers are required to run the computation. + TF_ASSIGN_OR_RETURN( + std::unique_ptr assignment, + BufferAssigner::Run( + hlo_module.get(), + MakeUnique(hlo_module.get(), module_sequence), + pointer_size)); + + IrEmitter ir_emitter(*hlo_module, *module_config, *assignment, &llvm_module, + /*hlo_to_profile_idx=*/nullptr); + HloComputation* computation = hlo_module->entry_computation(); + for (auto embedded_computation : + computation->MakeEmbeddedComputationsList()) { + TF_RETURN_IF_ERROR( + ir_emitter + .EmitComputation(embedded_computation, embedded_computation->name(), + /*is_entry_computation=*/false, + &module_sequence.at(embedded_computation)) + .status()); + } + const string& entry_point_name = options.entry_point_name(); + TF_ASSIGN_OR_RETURN( + llvm::Function * entry_function, + ir_emitter.EmitComputation(computation, entry_point_name, + /*is_entry_computation=*/true)); + + entry_function->setName(llvm_ir::AsStringRef(entry_point_name)); + + Disassembler disassembler(*target_machine); + CompilerFunctor compiler_functor(target_machine.get(), &disassembler, + opt_level, CompilerFunctor::AllIntrinsics()); + llvm::object::OwningBinary object_file = + compiler_functor(llvm_module); + llvm::StringRef object_file_data_ref = object_file.getBinary()->getData(); + ObjectFileData object_file_data(object_file_data_ref.begin(), + object_file_data_ref.end()); + + BufferSizes buffer_sizes; + for (const BufferAllocation& allocation : assignment->Allocations()) { + // Callers don't need to allocate temporary buffers for parameters. + if (allocation.is_entry_computation_parameter()) { + buffer_sizes.push_back(-1); + continue; + } + // Callers don't need to allocate anything for thread-local temporary + // buffers. They are lowered to allocas. + if (allocation.is_thread_local()) { + buffer_sizes.push_back(-1); + continue; + } + buffer_sizes.push_back(allocation.size()); + } + + TF_ASSIGN_OR_RETURN(const BufferAllocation* result_allocation, + assignment->GetUniqueTopLevelOutputAllocation()); + + return std::unique_ptr( + MakeUnique(std::move(object_file_data), + std::move(buffer_sizes), + result_allocation->index())); +} + +se::Platform::Id CpuCompiler::PlatformId() const { + return se::host::kHostPlatformId; +} + +} // namespace cpu +} // namespace xla + +REGISTER_MODULE_INITIALIZER(cpu_compiler, { + xla::Compiler::RegisterCompilerFactory(se::host::kHostPlatformId, []() { + return xla::MakeUnique(); + }); +}); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h new file mode 100644 index 0000000000..349724d840 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h @@ -0,0 +1,148 @@ +/* Copyright 2017 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_COMPILER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_COMPILER_H_ + +#include + +#include "tensorflow/compiler/xla/service/compiler.h" +#include "tensorflow/compiler/xla/service/executable.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace cpu { + +// This class wraps the configurability options that LLVM exposes including: the +// target triple, the target cpu and the target features. It also includes the +// desired linkage name for the computation entry point. +// Note that the optimization level can be controlled by the +// --xla_cpu_llvm_opt_level flag. +class CpuAotCompilationOptions : public AotCompilationOptions { + public: + // Relocation models available for compilation. + enum class RelocationModel { + // Corresponds to the -fno-pic compiler option. + Static, + // Corresponds to the -fpic compiler option. + SmallPic, + // Corresponds to the -fPIC compiler option. + BigPic, + // Corresponds to the -fpie compiler option. + SmallPie, + // Corresponds to the -fPIE compiler option. + BigPie + }; + + CpuAotCompilationOptions(string triple, string cpu_name, string features, + string entry_point_name, + RelocationModel relocation_model); + ~CpuAotCompilationOptions() override; + + perftools::gputools::Platform::Id PlatformId() const override; + + // The triple used for compilation, similar to clang's -target flag. + const string& triple() const { return triple_; } + // The CPU name used for compilation, similar to clang's -mcpu flag. + const string& cpu_name() const { return cpu_name_; } + // The target features used for compilation ("+avx2", "+neon", etc). + const string& features() const { return features_; } + // The name to be used for the compiled code's entry point. + const string& entry_point_name() const { return entry_point_name_; } + // The relocation model used for compilation. 
+ RelocationModel relocation_model() const { return relocation_model_; } + + private: + const string triple_; + const string cpu_name_; + const string features_; + const string entry_point_name_; + const RelocationModel relocation_model_; +}; + +class CpuAotCompilationResult : public AotCompilationResult { + public: + CpuAotCompilationResult(ObjectFileData object_file_data, + BufferSizes buffer_sizes, int64 result_buffer_index); + ~CpuAotCompilationResult(); + + const ObjectFileData& object_file_data() const { return object_file_data_; } + const BufferSizes& buffer_sizes() const { return buffer_sizes_; } + int64 result_buffer_index() const { return result_buffer_index_; } + + private: + // Contains the compiled computation: an object file. + const ObjectFileData object_file_data_; + + // The list of buffer sizes which should be allocated in order to execute the + // compiled computation. These buffers are used for temporary buffers used + // ephemerally during computation as well as the output result. + const BufferSizes buffer_sizes_; + + // Contains which buffer index into |buffer_sizes| was designated to the + // result of the computation. This buffer should be passed into the output + // parameter when calling the compiled computation. + const int64 result_buffer_index_; +}; + +// CPU-targeting implementation of the XLA Compiler interface. +// +// The compiler translates XLA HLO code into LLVM IR and uses LLVM's JIT +// infrastructure to create an executable "blob" that can then be returned +// wrapped in CpuExecutable and actually invoked. +class CpuCompiler : public Compiler { + public: + CpuCompiler(); + ~CpuCompiler() override {} + + StatusOr> Compile( + std::unique_ptr hlo_module, + std::unique_ptr module_config, HloDumper dump_hlo, + perftools::gputools::StreamExecutor* stream_exec) override; + + StatusOr>> Compile( + std::vector> hlo_module, + std::vector> module_config, + HloDumper dump_hlo, + std::vector stream_exec) override; + + StatusOr> CompileAheadOfTime( + std::unique_ptr module, + std::unique_ptr module_config, HloDumper dump_hlo, + const AotCompilationOptions& options) override; + + perftools::gputools::Platform::Id PlatformId() const override; + + private: + // Initialize the LLVM target. + static void InitializeLLVMTarget(); + + // Runs the HLO passes which are necessary for both optimizations and + // correctness. + Status RunHloPasses(HloModule* hlo_module, HloModuleConfig* module_config, + HloDumper dump_hlo); + + TF_DISALLOW_COPY_AND_ASSIGN(CpuCompiler); +}; + +} // namespace cpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_COMPILER_H_ diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc new file mode 100644 index 0000000000..727257d4f1 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -0,0 +1,477 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/cpu_executable.h" + +#include +#include +#include +#include +#include +#include + +#include "external/llvm/include/llvm/ExecutionEngine/Orc/IRCompileLayer.h" +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/computation_layout.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/logical_buffer.h" +#include "tensorflow/compiler/xla/service/shaped_buffer.h" +#include "tensorflow/compiler/xla/shape_tree.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mem.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/types.h" + +namespace se = ::perftools::gputools; + +namespace xla { +namespace cpu { + +CpuExecutable::CpuExecutable( + std::unique_ptr jit, + std::unique_ptr assignment, + std::unique_ptr hlo_module, + std::unique_ptr module_config, + const string& entry_function_name, + std::unordered_map hlo_to_profile_idx) + : Executable(std::move(hlo_module), std::move(module_config)), + jit_(std::move(jit)), + assignment_(std::move(assignment)), + hlo_to_profile_idx_(std::move(hlo_to_profile_idx)) { + // Resolve symbols in the constructor rather than at execution time to avoid + // races because FindSymbol is not thread safe. + llvm::JITSymbol sym = jit_->FindSymbol(entry_function_name); + // We expect to find the symbol provided with entry_function_name; otherwise + // this is an internal error. + CHECK(sym) << "Symbol " << entry_function_name << " not found."; + // getAddress can do work under the hood in the jit, so it needs to be + // guarded by the mutex. + compute_function_ = reinterpret_cast(sym.getAddress()); +} + +// Given a pointer to an output buffer (following the CPU JIT calling +// conventions), mark addresses that are "live". The initial pointer itself is +// trivially live. If the shape of the buffer is a tuple, this analysis looks +// into the tuple's elements and marks them live as well (since tuples keep +// pointers to buffers) and also works recursively. address is an in-memory +// buffer address that contains some runtime XLA object. shape is its +// shape. marked_addresses is the set of live addresses to populate. 
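+// For example, for a tuple-shaped output (f32[4], f32[8]) the top-level buffer
+// holds two element pointers; the tuple buffer itself and both pointed-to
+// element buffers end up in marked_addresses.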
+static void MarkLiveAddressesInOutput( + const void* address, const Shape& shape, + std::unordered_set* marked_addresses) { + marked_addresses->insert(address); + const uintptr_t* address_buffer = static_cast(address); + if (ShapeUtil::IsTuple(shape)) { + for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + const uintptr_t* element_address = address_buffer + i; + const void* element = reinterpret_cast(*element_address); + MarkLiveAddressesInOutput( + element, ShapeUtil::GetTupleElementShape(shape, i), marked_addresses); + } + } +} + +Status CpuExecutable::AllocateBuffers( + DeviceMemoryAllocator* memory_allocator, int device_ordinal, + std::vector* buffers) { + CHECK_EQ(buffers->size(), assignment_->Allocations().size()); + VLOG(3) << "Allocating " << assignment_->Allocations().size() + << " allocations for module " << module().name(); + for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size(); + ++i) { + auto& allocation = assignment_->GetAllocation(i); + + VLOG(3) << allocation.ToString(); + + if (allocation.is_entry_computation_parameter()) { + VLOG(3) << "allocation #" << i << " is a parameter"; + continue; + } + + if (allocation.is_thread_local()) { + VLOG(3) << "buffer #" << i << " is thread-local"; + continue; + } + + int64 buffer_size = allocation.size(); + if (!(*buffers)[i].is_null()) { + VLOG(3) << "buffer #" << i + << " is in the preallocated result ShapedBuffer"; + } else { + TF_ASSIGN_OR_RETURN((*buffers)[i], memory_allocator->Allocate( + device_ordinal, buffer_size)); + + VLOG(3) << "buffer #" << i << " allocated " << buffer_size << " bytes [" + << (*buffers)[i].opaque() << "]"; + } + + // Since the output buffer and all the temporary buffers were written into + // by the JITed code, msan has no way of knowing their memory was + // initialized. Mark them initialized so that msan doesn't flag loads from + // these buffers. + TF_ANNOTATE_MEMORY_IS_INITIALIZED((*buffers)[i].opaque(), buffer_size); + } + + TF_ASSIGN_OR_RETURN(const BufferAllocation* result_allocation, + assignment_->GetUniqueTopLevelOutputAllocation()); + + VLOG(3) << "result index: " << result_allocation->index(); + + return Status::OK(); +} + +Status CpuExecutable::ExecuteComputeFunction( + const ExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice arguments, + tensorflow::gtl::ArraySlice buffers, + HloExecutionProfile* hlo_execution_profile) { + std::vector argument_buffers; + for (int i = 0; i < arguments.size(); ++i) { + TF_RET_CHECK(!ShapeUtil::IsTuple(arguments[i]->shape())); + argument_buffers.push_back(arguments[i]->buffer(/*index=*/{})); + } + return ExecuteComputeFunction(run_options, argument_buffers, buffers, + hlo_execution_profile); +} + +Status CpuExecutable::ExecuteComputeFunction( + const ExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice arguments, + tensorflow::gtl::ArraySlice buffers, + HloExecutionProfile* hlo_execution_profile) { + // The calling convention for JITed functions is: + // + // void function(void* result, const void* run_options, void** args_array, + // void** temps_array) + // + // result: Points at the result. + // run_options: the ExecutableRunOptions object. + // args_array: An array of pointers, each of which points to a parameter. + // The size of this array is determined by the function's arity + // (ProgramShape). + // temps_array: An array of pointers, each of which points to a temporary + // buffer the computation needs. The size of this array is + // determined by buffer analysis. 
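+  // profile_counters: An array of uint64 counters, one slot per profiled HLO
+  // instruction plus a final slot for the whole computation (not listed in
+  // the signature above; see ComputeFunctionType in cpu_executable.h).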
+ // + std::vector args_array; + for (se::DeviceMemoryBase arg_mem : arguments) { + args_array.push_back(arg_mem.opaque()); + } + + uint64 start_micros = tensorflow::Env::Default()->NowMicros(); + + // Allocate profiling counters for each hlo instruction that we would like to + // profile. Allocate an additional profile counter for the entire + // computation. + std::vector profile_counters(hlo_to_profile_idx_.size() + 1); + + // Call the computation function following the calling convention. + std::vector buffer_pointers; + for (auto& buffer : buffers) { + buffer_pointers.push_back(const_cast(buffer.opaque())); + } + TF_ASSIGN_OR_RETURN(const BufferAllocation* result_allocation, + assignment_->GetUniqueTopLevelOutputAllocation()); + void* result_buffer = buffer_pointers[result_allocation->index()]; + if (VLOG_IS_ON(3)) { + VLOG(3) << "Executing compute function:"; + VLOG(3) << tensorflow::strings::Printf( + " func(void* result, void* params[%zu], void* temps[%zu], " + "uint64 profile_counters[%zu])", + args_array.size(), buffer_pointers.size(), profile_counters.size()); + VLOG(3) << tensorflow::strings::Printf(" result = %p", result_buffer); + auto ptr_printer = [](string* out, const void* p) { + tensorflow::strings::StrAppend(out, tensorflow::strings::Printf("%p", p)); + }; + VLOG(3) << tensorflow::strings::Printf( + " params = [%s]", + tensorflow::str_util::Join(args_array, ", ", ptr_printer).c_str()); + VLOG(3) << tensorflow::strings::Printf( + " temps = [%s]", + tensorflow::str_util::Join(buffer_pointers, ", ", ptr_printer).c_str()); + VLOG(3) << tensorflow::strings::Printf(" profile_counters = %p", + profile_counters.data()); + } + + compute_function_(result_buffer, run_options, args_array.data(), + buffer_pointers.data(), profile_counters.data()); + + uint64 end_micros = tensorflow::Env::Default()->NowMicros(); + + { + tensorflow::mutex_lock lock(mutex_); + const double nanoseconds = (end_micros - start_micros) * 1000.0; + execution_profile_.set_compute_time_ns(std::max(nanoseconds, 1.0)); + + // The last profile counter is used for the computation as a whole. + execution_profile_.set_compute_cycle_count(profile_counters.back()); + } + + if (hlo_execution_profile != nullptr) { + hlo_execution_profile->set_total_cycles_executed(profile_counters.back()); + + for (auto hlo_prof_idx : hlo_to_profile_idx_) { + const HloInstruction* hlo = hlo_prof_idx.first; + uint64 cycles_taken = profile_counters[hlo_prof_idx.second]; + hlo_execution_profile->AddProfileResult(hlo, cycles_taken); + } + } + return Status::OK(); +} + +StatusOr CpuExecutable::ExecuteOnStream( + const ExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice arguments, + HloExecutionProfile* hlo_execution_profile) { + se::Stream* stream = run_options->stream(); + DeviceMemoryAllocator* memory_allocator = run_options->allocator(); + std::vector buffers(assignment_->Allocations().size()); + TF_RETURN_IF_ERROR(AllocateBuffers( + memory_allocator, stream->parent()->device_ordinal(), &buffers)); + + TF_RETURN_IF_ERROR(ExecuteComputeFunction(run_options, arguments, buffers, + hlo_execution_profile)); + + // Mark the buffers that are actually live (used in the output) when the + // computation finishes executing. 
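+  // Any temporary buffer whose address is not reachable from the output is
+  // deallocated in the loop below.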
+ std::unordered_set marked_addresses; + TF_ASSIGN_OR_RETURN(const BufferAllocation* result_allocation, + assignment_->GetUniqueTopLevelOutputAllocation()); + se::DeviceMemoryBase top_level_output = buffers[result_allocation->index()]; + MarkLiveAddressesInOutput(top_level_output.opaque(), result_shape(), + &marked_addresses); + + VLOG(3) << "Live addresses in output marking found " + << marked_addresses.size() << " addresses:\n" + << tensorflow::str_util::Join( + marked_addresses, ", ", [](string* out, const void* address) { + tensorflow::strings::StrAppend( + out, tensorflow::strings::Printf("%p", address)); + }); + + // Computation is done - deallocate temp buffers. Keep those marked live + // because they are referenced by the output of the computation and are needed + // by the service. They will be deallocated by the service. + for (auto i = 0; i < buffers.size(); ++i) { + auto alloc = buffers[i]; + if (marked_addresses.count(alloc.opaque()) == 0 && + alloc.opaque() != nullptr) { + VLOG(3) << "CpuExecutable deallocating buffer #" << i << " [" + << alloc.opaque() << "]"; + TF_RETURN_IF_ERROR(memory_allocator->Deallocate( + stream->parent()->device_ordinal(), &alloc)); + } + } + + return top_level_output; +} + +StatusOr> CpuExecutable::ExecuteOnStream( + const ExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice arguments, + HloExecutionProfile* hlo_execution_profile) { + se::Stream* stream = run_options->stream(); + DeviceMemoryAllocator* memory_allocator = run_options->allocator(); + if (GetRootPointsToSet().IsAmbiguous()) { + return Unimplemented("Points-to set of root instruction is ambiguous"); + } + std::vector buffers(assignment_->Allocations().size()); + + TF_ASSIGN_OR_RETURN( + std::unique_ptr result_buffer, + ShapedBuffer::MakeShapedBuffer( + module_config().entry_computation_layout().result_shape(), + stream->parent()->platform(), stream->parent()->device_ordinal())); + + TF_RETURN_IF_ERROR(AllocateBuffers( + memory_allocator, stream->parent()->device_ordinal(), &buffers)); + + TF_RETURN_IF_ERROR(ExecuteComputeFunction(run_options, arguments, buffers, + hlo_execution_profile)); + + // Copy DeviceMemoryBase values which contain the array(s) of the result into + // the respective location in ShapedBuffer which is returned to the caller. + std::vector buffers_in_result(assignment_->Allocations().size(), false); + TF_RETURN_IF_ERROR( + result_buffer->mutable_shape_index_to_buffer_entry() + ->ForEachMutableElement( + [&buffers, &buffers_in_result, &result_buffer, this]( + const ShapeIndex& index, bool is_leaf, size_t* buffer_entry) { + if (is_leaf) { + const std::vector& sources = + this->GetRootPointsToSet().element(index); + // The points to set is unambiguous so the set should be a + // singleton. + CHECK_EQ(1, sources.size()); + const LogicalBuffer* buffer_source = sources[0]; + HloInstruction* src = buffer_source->instruction(); + + // The source for this result buffer can be a nested buffer + // such as a tuple element. + + // The source instruction should have a non-parameter buffer + // assigned. 
+ TF_ASSIGN_OR_RETURN(const BufferAllocation* allocation, + this->assignment_->GetUniqueAllocation( + src, buffer_source->index())); + CHECK(!allocation->is_entry_computation_parameter()); + + CHECK(!buffers[allocation->index()].is_null() || + buffers[allocation->index()].size() == 0); + result_buffer->mutable_buffers()->push_back( + buffers[allocation->index()]); + *buffer_entry = result_buffer->mutable_buffers()->size() - 1; + buffers_in_result[allocation->index()] = true; + } + return Status::OK(); + })); + + // Free all buffers not in the result. + for (auto i = 0; i < buffers.size(); ++i) { + auto alloc = buffers[i]; + if (!buffers_in_result[i] && !alloc.is_null()) { + VLOG(3) << "CpuExecutable deallocating buffer #" << i << " [" + << alloc.opaque() << "]"; + TF_RETURN_IF_ERROR(memory_allocator->Deallocate( + stream->parent()->device_ordinal(), &alloc)); + } + } + + return std::move(result_buffer); +} + +Status CpuExecutable::ExecuteOnStream( + const ExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice arguments, + ShapedBuffer* result_buffer, HloExecutionProfile* hlo_execution_profile) { + se::Stream* stream = run_options->stream(); + DeviceMemoryAllocator* memory_allocator = run_options->allocator(); + // Every array element in the result of the computation must be unambiguously + // produced by a single instruction. + // This ensures that the buffers inside result_buffer can be assigned without + // conflict to the respective instructions because there is a one-to-one + // correspondence between hlo instructions and array buffers in the result. + if (GetRootPointsToSet().IsAmbiguous()) { + return Unimplemented( + "Points-to set of root instruction is ambiguous or not distinct"); + } + std::vector buffers(assignment_->Allocations().size()); + DCHECK(ShapeUtil::Compatible(result_buffer->shape(), result_shape())); + + // If two tuple elements point to the same buffer, one of the results in the + // result buffer is considered the canonical location while the other result + // points to it (instead of, say, making a copy of the result). + // buffer_index_to_shape_index maps a buffer index to its canonical location + // in the result buffer. + std::unordered_map + buffer_index_to_shape_index; + + // Copy values from result_buffer to the index in "buffers". These buffers + // will not be allocated in the call to AllocateBuffers. + std::vector buffers_in_result(assignment_->Allocations().size(), false); + TF_RETURN_IF_ERROR( + result_buffer->mutable_shape_index_to_buffer_entry() + ->ForEachMutableElement( + [&buffers, &buffers_in_result, &buffer_index_to_shape_index, + result_buffer, this](const ShapeIndex& index, bool is_leaf, + size_t* buffer_entry) { + if (is_leaf) { + const std::vector& sources = + this->GetRootPointsToSet().element(index); + // The points to set is unambiguous so the set should be a + // singleton. + CHECK_EQ(1, sources.size()); + const LogicalBuffer* buffer_source = sources[0]; + HloInstruction* src = buffer_source->instruction(); + + // The source for this result buffer can be a nested buffer + // such as a tuple element. + + // The source instruction should have a non-parameter buffer + // assigned. 
+ TF_ASSIGN_OR_RETURN(const BufferAllocation* allocation, + this->assignment_->GetUniqueAllocation( + src, buffer_source->index())); + CHECK(!allocation->is_entry_computation_parameter()); + + auto insert_result = buffer_index_to_shape_index.emplace( + allocation->index(), *buffer_entry); + if (insert_result.second) { + // The points-to set is distinct so this buffer should not + // have + // been assigned in a previous invocation of this lambda. + perftools::gputools::DeviceMemoryBase memory_base = + result_buffer->buffer(index); + CHECK(buffers[allocation->index()].is_null()); + CHECK(!memory_base.is_null()); + buffers[allocation->index()] = memory_base; + buffers_in_result[allocation->index()] = true; + } else { + // Record the fact that this tuple element is identical to + // some + // prior result. + *buffer_entry = insert_result.first->second; + } + } + return Status::OK(); + })); + + TF_RETURN_IF_ERROR(AllocateBuffers( + memory_allocator, stream->parent()->device_ordinal(), &buffers)); + + TF_RETURN_IF_ERROR(ExecuteComputeFunction(run_options, arguments, buffers, + hlo_execution_profile)); + + // Free all buffers not in the result. + for (auto i = 0; i < buffers.size(); ++i) { + auto alloc = buffers[i]; + if (!buffers_in_result[i] && !alloc.is_null()) { + VLOG(3) << "CpuExecutable deallocating buffer #" << i << " [" + << alloc.opaque() << "]"; + TF_RETURN_IF_ERROR(memory_allocator->Deallocate( + stream->parent()->device_ordinal(), &alloc)); + } + } + + return Status::OK(); +} + +StatusOr +CpuExecutable::ExecuteAsyncOnStream( + const ExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice arguments) { + // TODO(b/30671675): Implement asynchronous execution mode. + return Unimplemented( + "Asynchronous execution on stream is not yet supported on CPU."); +} + +const PointsToSet& CpuExecutable::GetRootPointsToSet() const { + return assignment_->points_to_analysis().GetPointsToSet( + module().entry_computation()->root_instruction()); +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h new file mode 100644 index 0000000000..8f3247e683 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h @@ -0,0 +1,150 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_EXECUTABLE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_EXECUTABLE_H_ + +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/executable.h" +#include "tensorflow/compiler/xla/service/hlo_execution_profile.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/shaped_buffer.h" +#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace cpu { + +// CPU-targeting implementation of the XLA Executable interface. +// +// Wraps a JIT-ed object that can be executed "on device". We JIT for the host +// architecture, so JIT-ed code and host code share the same ABI. +class CpuExecutable : public Executable { + public: + CpuExecutable( + std::unique_ptr jit, + std::unique_ptr assignment, + std::unique_ptr hlo_module, + std::unique_ptr module_config, + const string& entry_function_name, + std::unordered_map hlo_to_profile_idx); + ~CpuExecutable() override {} + + StatusOr ExecuteOnStream( + const ExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice + arguments, + HloExecutionProfile* hlo_execution_profile) override; + + StatusOr> ExecuteOnStream( + const ExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice arguments, + HloExecutionProfile* hlo_execution_profile) override; + + Status ExecuteOnStream( + const ExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice arguments, + ShapedBuffer* result_buffer, + HloExecutionProfile* hlo_execution_profile) override; + + StatusOr ExecuteAsyncOnStream( + const ExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice + arguments) override; + + // This should be called after set_ir_module_string. + const string& ir_module_string() const { return ir_module_string_; } + + void set_ir_module_string(const string& ir_module_string) { + ir_module_string_ = ir_module_string; + } + + private: + // Allocate buffers required for execution and assign them to the elements of + // "buffers". "buffers" should be sized to the number of buffers in buffer + // assignment. Each vector element corresponds to a particular Index. If + // a vector element already contains a non-null DeviceMemoryBase, then no + // buffer is assigned for this element. + Status AllocateBuffers( + DeviceMemoryAllocator* memory_allocator, int device_ordinal, + std::vector* buffers); + + // Calls the generated function performing the computation with the given + // arguments using the supplied buffers. 
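+  // The first overload flattens its (non-tuple) ShapedBuffer arguments into
+  // DeviceMemoryBase values and forwards to the second overload.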
+ Status ExecuteComputeFunction( + const ExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice + arguments, + tensorflow::gtl::ArraySlice + buffers, + HloExecutionProfile* hlo_execution_profile); + Status ExecuteComputeFunction( + const ExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice arguments, + tensorflow::gtl::ArraySlice + buffers, + HloExecutionProfile* hlo_execution_profile); + + // Returns the points-to set of the root instruction of the entry + // computation. Uses points-to analysis from buffer assignment. + const PointsToSet& GetRootPointsToSet() const; + + // The JIT containing compiled modules. + std::unique_ptr jit_; + + // Buffer assignment for the buffers we need to allocate. + std::unique_ptr assignment_; + + // The LLVM IR, in string format, of the unoptimized module generated for this + // CpuExecutable. We save a string instead of an llvm::Module* because leaving + // llvm::Module* in a singleton can cause the heap checker to emit false + // positives. + string ir_module_string_; + + // Type of the computation function we expect in the JIT. + // void function(void* result, const void* run_options, + // const void** args_array, void** temps_array) + using ComputeFunctionType = void (*)(void*, const void*, const void**, void**, + uint64*); + ComputeFunctionType compute_function_; + + // Entry function name for the computation. + const string entry_function_name_; + + // Maps HLOs to their index into the profile counter array. + const std::unordered_map hlo_to_profile_idx_; + + TF_DISALLOW_COPY_AND_ASSIGN(CpuExecutable); +}; + +} // namespace cpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_EXECUTABLE_H_ diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc new file mode 100644 index 0000000000..240da35ef1 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc @@ -0,0 +1,44 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h" + +#include "tensorflow/compiler/xla/service/hlo_opcode.h" + +namespace xla { +namespace cpu { + +bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, + int64 operand_index) { + HloInstruction* producer = consumer->mutable_operand(operand_index); + + // Condition for consumer: must be elementwise or a fusion op + // (which necessarily only contains elementwise operations) + if (!(consumer->opcode() == HloOpcode::kFusion || + consumer->IsElementwise())) { + return false; + } + + // Producer or consumer cannot be Map. Maps are technically elementwise but + // of a slightly different form (call instead of a computation). These are not + // yet supported in the CPU backend. 
+ return producer->IsElementwise() && producer->operand_count() > 0 && + producer->opcode() != HloOpcode::kMap && + consumer->opcode() != HloOpcode::kMap && + InstructionFusion::ShouldFuse(consumer, operand_index); +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h new file mode 100644 index 0000000000..b7c646ad47 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h @@ -0,0 +1,37 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_INSTRUCTION_FUSION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_INSTRUCTION_FUSION_H_ + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/instruction_fusion.h" + +namespace xla { +namespace cpu { + +class CpuInstructionFusion : public InstructionFusion { + public: + CpuInstructionFusion() {} + ~CpuInstructionFusion() override {} + + protected: + bool ShouldFuse(HloInstruction* consumer, int64 operand_index) override; +}; + +} // namespace cpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_INSTRUCTION_FUSION_H_ diff --git a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc new file mode 100644 index 0000000000..7ae81929c0 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc @@ -0,0 +1,120 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/logical_buffer.h" +#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace xla { +namespace cpu { + +StatusOr ParallelizationPreparation::Run(HloModule* module) { + bool changed = false; + HloComputation* entry_computation = module->entry_computation(); + std::unordered_set outlined; + std::vector instructions_to_outline; + for (HloInstruction* instruction : + entry_computation->MakeInstructionPostOrder()) { + // If the instruction has been outlined, it no longer exists and we must not + // dereference it. + if (outlined.count(instruction) > 0) { + continue; + } + + // Skip parameters and constants, there is nothing to parallelize. + if (instruction->opcode() == HloOpcode::kParameter || + instruction->opcode() == HloOpcode::kConstant) { + continue; + } + instructions_to_outline.clear(); + HloInstruction* outline_candidate = instruction; + instructions_to_outline.push_back(outline_candidate); + bool all_bitcasts = outline_candidate->opcode() == HloOpcode::kBitcast; + + // Outline sole users with the current instruction. + while (outline_candidate->users().size() == 1) { + HloInstruction* prior_candidate = outline_candidate; + outline_candidate = *outline_candidate->users().begin(); + all_bitcasts |= outline_candidate->opcode() == HloOpcode::kBitcast; + if (std::any_of(outline_candidate->operands().begin(), + outline_candidate->operands().end(), + [&](const HloInstruction* operand) { + // Do not consider any candidates which have operands + // other than the prior candidate, constants or + // parameters. Otherwise, we'd increase the fan-in which + // would reduce parallelism. + return operand->opcode() != HloOpcode::kParameter && + operand->opcode() != HloOpcode::kConstant && + operand != prior_candidate; + })) { + break; + } + instructions_to_outline.push_back(outline_candidate); + } + // If all instructions in the outline candidates are a bitcast, then create + // a copy at the head of the bitcasts and include it in the outlined + // instructions. The underlying problem is that a computation which forwards + // a parameter buffer to the output is not properly handled by the backends + // or analysis. + // + // This would be better handled by being smarter about choosing outline + // candidates in the first place. + if (all_bitcasts) { + // 'head' is the first instruction in the chain of bitcasts. 
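+      // Prepending a kCopy gives the outlined computation a buffer that it
+      // defines itself rather than one aliased from a parameter.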
+ HloInstruction* head = instructions_to_outline[0]; + HloInstruction* head_operand = head->mutable_operand(0); + HloInstruction* copy = + entry_computation->AddInstruction(HloInstruction::CreateUnary( + head_operand->shape(), HloOpcode::kCopy, head_operand)); + head->ReplaceOperandWith(0, copy); + instructions_to_outline.insert(instructions_to_outline.begin(), copy); + } + + outlined.insert(instructions_to_outline.begin(), + instructions_to_outline.end()); + + module->OutlineExpressionFromComputation( + instructions_to_outline, + tensorflow::strings::StrCat("computation_for_", instruction->name()), + entry_computation); + changed = true; + } + + TF_ASSIGN_OR_RETURN(auto points_to_analysis, + TuplePointsToAnalysis::Run(module)); + for (auto& computation : module->computations()) { + HloInstruction* root = computation->root_instruction(); + // Copy root instruction if it does not define its own top-level buffer. + // TODO(b/32885001) Remove these copies (at least for the unambiguous case). + // TODO(b/32885001) Perform shallow copy if root value is a tuple. + if (!points_to_analysis->InstructionDefinesBufferAtIndex(root, + /*index=*/{})) { + HloInstruction* copy = computation->AddInstruction( + HloInstruction::CreateUnary(root->shape(), HloOpcode::kCopy, root)); + computation->set_root_instruction(copy); + changed = true; + } + } + return changed; +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h new file mode 100644 index 0000000000..3d6cfb258f --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h @@ -0,0 +1,46 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_PARALLELIZATION_PREPARATION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_PARALLELIZATION_PREPARATION_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass.h" + +namespace xla { +namespace cpu { + +// This pass prepares an HLO module for parallel execution by transforming +// subgraphs of the top-level computation into embedded computations which can +// be executed in parallel. +// TODO(b/29630486): Currently, it is limited to turning all instructions (which +// are not constants or parameters) in the entry computation into embedded +// computations. However, it could make sense to coarsen the parallelization to +// improve cache locality. Also, we will need to do something to intelligently +// handle While constructs. +class ParallelizationPreparation : public HloPass { + public: + explicit ParallelizationPreparation() : HloPass("cpu-parallel-prepare") {} + ~ParallelizationPreparation() override {} + + // Run instruction fusion on the given computation. Returns whether the + // computation was changed. 
+ StatusOr Run(HloModule* module) override; +}; + +} // namespace cpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_PARALLELIZATION_PREPARATION_H_ diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc new file mode 100644 index 0000000000..8e06f0520e --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -0,0 +1,52 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" + +#include +#include + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" + +namespace xla { +namespace cpu { +namespace runtime { + +InfeedManager* GetInfeedManager() { + static InfeedManager* manager = new InfeedManager; + return manager; +} + +} // namespace runtime +} // namespace cpu +} // namespace xla + +void* __xla_cpu_runtime_AcquireInfeedBufferForDequeue( + xla::int32 buffer_length) { + xla::cpu::runtime::InfeedManager* infeed = + xla::cpu::runtime::GetInfeedManager(); + // Wait until there's a buffer to dequeue. + xla::cpu::runtime::InfeedBuffer* buffer = infeed->BlockingDequeueBuffer(); + CHECK_EQ(buffer->length(), buffer_length); + return buffer->data(); +} + +void __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(xla::int32 buffer_length, + void* buffer_ptr) { + xla::cpu::runtime::InfeedManager* infeed = + xla::cpu::runtime::GetInfeedManager(); + infeed->ReleaseCurrentBuffer(buffer_length, buffer_ptr); +} diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h new file mode 100644 index 0000000000..8eae210230 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h @@ -0,0 +1,91 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This header declares functions which may be called by the generated code on +// the CPU. Calls to these functions must be resolved explicitly in the JIT in +// xla::cpu::SimpleResolver. It also defines a per-CpuExecutable context +// which is used to cache expensive state and resources utilized by the +// aforementioned functions. +// +// Other functions are declared in individual libraries as well, such as +// runtime_conv2d and runtime_matmul. 
As individual libraries, callers for +// ahead-of-time compilation can link only the required subset. + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_H_ + +#include "tensorflow/compiler/xla/service/cpu/infeed_manager.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { +namespace cpu { +namespace runtime { + +// Names of runtime functions. These get resolved from the generated code to the +// right symbol at link time in one of two ways: +// 1. When using the JIT, the symbol resolver (SimpleResolver in +// third_party/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc) maps +// this symbol name to +// the actual symbol. +// 2. When using ahead-of-time compilation, the linker can resolve the name +// because it is a symbol in the cpu_runtime library. +constexpr char kEigenMatmulF32SymbolName[] = "__xla_cpu_runtime_EigenMatMulF32"; +constexpr char kEigenMatmulF64SymbolName[] = "__xla_cpu_runtime_EigenMatMulF64"; +constexpr char kEigenConvF32SymbolName[] = "__xla_cpu_runtime_EigenConvF32"; +constexpr char kEigenSingleThreadedMatmulF32SymbolName[] = + "__xla_cpu_runtime_EigenSingleThreadedMatMulF32"; +constexpr char kEigenSingleThreadedMatmulF64SymbolName[] = + "__xla_cpu_runtime_EigenSingleThreadedMatMulF64"; +constexpr char kEigenSingleThreadedConvF32SymbolName[] = + "__xla_cpu_runtime_EigenSingleThreadedConvF32"; +constexpr char kAcquireInfeedBufferForDequeueSymbolName[] = + "__xla_cpu_runtime_AcquireInfeedBufferForDequeue"; +constexpr char kReleaseInfeedBufferAfterDequeueSymbolName[] = + "__xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue"; + +// Returns the infeed manager used by the CPU runtime. +InfeedManager* GetInfeedManager(); + +} // namespace runtime +} // namespace cpu +} // namespace xla + +extern "C" { + +// Blocks until the next infeed buffer is ready to be dequeued, then +// returns it. Fails catastrophically if the next enqueued buffer is +// not of the correct length in bytes. Checking the shape rather than +// the length would be more exact, but the length check is chosen as a +// tradeoff between error checking and speed/simplicity. +extern void* __xla_cpu_runtime_AcquireInfeedBufferForDequeue( + xla::int32 buffer_length); + +// Relinquishes the next infeed buffer that was returned by +// __xla_cpu_runtime_AcquireInfeedBufferForDequeue. Once this call +// completes the data at buffer_ptr may no longer be +// accessed. buffer_length must match the length passed to the call to +// __xla_cpu_runtime_AcquireInfeedBufferForDequeue that returned +// buffer_ptr. This function must be called before the next buffer is +// acquired, i.e., there may only be one outstanding infeed buffer in +// use by the runtime. TODO(b/31340454) investigate whether or not it +// is worth supporting zero-copy infeed where the buffer is retained +// by the compiled code until it has been used. If zero-copy infeed is +// implemented we will add support for multiple outstanding buffers +// that can be returned out of order. +extern void __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue( + xla::int32 buffer_length, void* buffer_ptr); +} + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_H_ diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.cc new file mode 100644 index 0000000000..646254887c --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.cc @@ -0,0 +1,36 @@ +/* Copyright 2017 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h" + +#define EIGEN_USE_THREADS + +#include "third_party/eigen3/Eigen/Core" + +namespace xla { +namespace cpu { +namespace runtime { + +#ifdef __AVX__ +V8F32 ExpV8F32(V8F32 x) { return Eigen::internal::pexp(x); } + +V8F32 LogV8F32(V8F32 x) { return Eigen::internal::plog(x); } + +V8F32 TanhV8F32(V8F32 x) { return Eigen::internal::ptanh(x); } +#endif // __AVX__ + +} // namespace runtime +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h new file mode 100644 index 0000000000..89721aaf83 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h @@ -0,0 +1,50 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This header declares functions which may be called by the generated code on +// the CPU. Calls to these functions must be resolved explicitly in the JIT in +// xla::cpu::SimpleResolver. It also defines a per-CpuExecutable context +// which is used to cache expensive state and resources utilized by the +// aforementioned functions. + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_AVX_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_AVX_H_ + +#include "tensorflow/core/platform/macros.h" + +namespace xla { +namespace cpu { +namespace runtime { + +constexpr char kExpV8F32[] = "__xla_cpu_runtime_ExpV8F32"; +constexpr char kLogV8F32[] = "__xla_cpu_runtime_LogV8F32"; +constexpr char kTanhV8F32[] = "__xla_cpu_runtime_TanhV8F32"; + +typedef float V8F32 __attribute__((__vector_size__(32))); + +// The following functions are vectorized versions of a selection of libm +// library functions. +// References to these functions are created by the LLVM vectorizer. 
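// [Editorial example, not part of the original change] A minimal sketch of
// calling one of these helpers directly on an AVX-capable host. V8F32 is the
// 8-lane float vector type defined above via the GCC/Clang vector extension;
// the braced initializer below is that extension's syntax.
//
//   #ifdef __AVX__
//   V8F32 x = {0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f};
//   V8F32 y = xla::cpu::runtime::ExpV8F32(x);  // element-wise expf on 8 lanes
//   #endif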
+V8F32 ExpV8F32(V8F32 x) TF_ATTRIBUTE_WEAK; + +V8F32 LogV8F32(V8F32 x) TF_ATTRIBUTE_WEAK; + +V8F32 TanhV8F32(V8F32 x) TF_ATTRIBUTE_WEAK; + +} // namespace runtime +} // namespace cpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_AVX_H_ diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.cc new file mode 100644 index 0000000000..69d04427c6 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.cc @@ -0,0 +1,47 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h" + +#define EIGEN_USE_THREADS + +#include "third_party/eigen3/Eigen/Core" + +namespace xla { +namespace cpu { +namespace runtime { + +#ifdef __SSE4_1__ + +V4F32 ExpV4F32(V4F32 x) { + Eigen::internal::Packet4f p = x; + return Eigen::internal::pexp(p); +} + +V4F32 LogV4F32(V4F32 x) { + Eigen::internal::Packet4f p = x; + return Eigen::internal::plog(p); +} + +V4F32 TanhV4F32(V4F32 x) { + Eigen::internal::Packet4f p = x; + return Eigen::internal::ptanh(p); +} + +#endif // __SSE4_1__ + +} // namespace runtime +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h new file mode 100644 index 0000000000..ded206f90a --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h @@ -0,0 +1,50 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This header declares functions which may be called by the generated code on +// the CPU. Calls to these functions must be resolved explicitly in the JIT in +// xla::cpu::SimpleResolver. It also defines a per-CpuExecutable context +// which is used to cache expensive state and resources utilized by the +// aforementioned functions. 
+ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_SSE4_1_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_SSE4_1_H_ + +#include "tensorflow/core/platform/macros.h" + +namespace xla { +namespace cpu { +namespace runtime { + +constexpr char kExpV4F32[] = "__xla_cpu_runtime_ExpV4F32"; +constexpr char kLogV4F32[] = "__xla_cpu_runtime_LogV4F32"; +constexpr char kTanhV4F32[] = "__xla_cpu_runtime_TanhV4F32"; + +typedef float V4F32 __attribute__((__vector_size__(16))); + +// The following functions are vectorized versions of a selection of libm +// library functions. +// References to these functions are created by the LLVM vectorizer. +V4F32 ExpV4F32(V4F32 x) TF_ATTRIBUTE_WEAK; + +V4F32 LogV4F32(V4F32 x) TF_ATTRIBUTE_WEAK; + +V4F32 TanhV4F32(V4F32 x) TF_ATTRIBUTE_WEAK; + +} // namespace runtime +} // namespace cpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_SSE4_1_H_ diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc new file mode 100644 index 0000000000..52eed7dbad --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc @@ -0,0 +1,138 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" + +#include +#include + +#define EIGEN_USE_THREADS + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/common_runtime/eigen_thread_pool.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +class CpuRuntimeTest : public ::testing::Test {}; + +template +std::unique_ptr> MaybeTransposeArray2D(const Array2D& array, + bool transpose) { + int64 output_height = array.height(); + int64 output_width = array.width(); + if (transpose) { + std::swap(output_width, output_height); + } + auto output = MakeUnique>(output_height, output_width); + for (int y = 0; y < array.height(); y++) { + for (int x = 0; x < array.width(); x++) { + if (transpose) { + (*output)(x, y) = array(y, x); + } else { + (*output)(y, x) = array(y, x); + } + } + } + return output; +} + +// Verifies that matrix 'c' equals the result of matrix 'a' times matrix 'b'. +// Each element is compared to within a small error bound. 
+void CheckMatrixMultiply(const Array2D& a, const Array2D& b, + const Array2D& c) { + for (int i = 0; i < a.height(); ++i) { + for (int j = 0; j < b.width(); ++j) { + float sum = 0.0; + for (int k = 0; k < a.width(); ++k) { + sum += a(i, k) * b(k, j); + } + EXPECT_NEAR(sum, c(i, j), 0.01); + } + } +} + +std::unique_ptr> EigenMatrixMultiply(const Array2D& a, + const Array2D& b, + bool transpose_lhs, + bool transpose_rhs) { + tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "XLAEigen", + 2); + tensorflow::EigenThreadPoolWrapper tp(&pool); + Eigen::ThreadPoolDevice device(&tp, tp.NumThreads()); + ExecutableRunOptions run_options; + run_options.set_intra_op_thread_pool(&device); + + CHECK_EQ(a.width(), b.height()); + int64 m = a.height(); + int64 n = b.width(); + int64 k = a.width(); + + // The Eigen matmul runtime function expects the matrix to be in column major + // order and array2d is in row-major order. Create transposes of a and b. The + // 'data' buffer in the transposed array is the original array in column major + // order. + auto a_transpose = MaybeTransposeArray2D(a, !transpose_lhs); + auto b_transpose = MaybeTransposeArray2D(b, !transpose_rhs); + + // Since we're going to transpose c before returning it. Swap the order of the + // dimension sizes to ensure the returned array is properly dimensioned. + auto c_transpose = MakeUnique>(n, m); + __xla_cpu_runtime_EigenMatMulF32(&run_options, c_transpose->data(), + a_transpose->data(), b_transpose->data(), m, + n, k, transpose_lhs, transpose_rhs); + return MaybeTransposeArray2D(*c_transpose, true); +} + +TEST_F(CpuRuntimeTest, SmallEigenMatmul) { + Array2D a({{1.0f, 2.0f}, {3.0f, 4.0f}}); + Array2D b({{5.0f, -1.0f, 3.0f}, {2.0f, 6.0f, 4.0f}}); + + for (bool transpose_lhs : {false, true}) { + for (bool transpose_rhs : {false, true}) { + auto c = EigenMatrixMultiply(a, b, transpose_lhs, transpose_rhs); + + LOG(INFO) << "a = " << a.ToString(); + LOG(INFO) << "b = " << b.ToString(); + LOG(INFO) << "c = " << c->ToString(); + + CheckMatrixMultiply(a, b, *c); + } + } +} + +TEST_F(CpuRuntimeTest, LargeEigenMatmul) { + auto a = MakeLinspaceArray2D(0.0, 1.0, 256, 512); + auto b = MakeLinspaceArray2D(-2.0, 2.0, 512, 1024); + + for (bool transpose_lhs : {false, true}) { + for (bool transpose_rhs : {false, true}) { + auto c = EigenMatrixMultiply(*a, *b, transpose_lhs, transpose_rhs); + + CheckMatrixMultiply(*a, *b, *c); + } + } +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/disassembler.cc b/tensorflow/compiler/xla/service/cpu/disassembler.cc new file mode 100644 index 0000000000..f0dcce56b4 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/disassembler.cc @@ -0,0 +1,182 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/disassembler.h" + +#include +#include +// IWYU pragma: no_include +#include +#include + +#include "external/llvm/include/llvm/MC/MCInst.h" +#include "external/llvm/include/llvm/Support/TargetRegistry.h" +#include "external/llvm/include/llvm/Support/raw_ostream.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace cpu { + +Disassembler::Disassembler(const llvm::TargetMachine& target_machine) + : subtarget_info_(*target_machine.getMCSubtargetInfo()) { + objfile_info_.reset(new llvm::MCObjectFileInfo()); + mc_context_.reset(new llvm::MCContext(target_machine.getMCAsmInfo(), + target_machine.getMCRegisterInfo(), + objfile_info_.get())); + disassembler_.reset(target_machine.getTarget().createMCDisassembler( + subtarget_info_, *mc_context_)); + inst_printer_.reset(target_machine.getTarget().createMCInstPrinter( + target_machine.getTargetTriple(), + /*SyntaxVariant=*/0, // Use AT&T syntax. + *target_machine.getMCAsmInfo(), *target_machine.getMCInstrInfo(), + *target_machine.getMCRegisterInfo())); + inst_analysis_.reset(target_machine.getTarget().createMCInstrAnalysis( + target_machine.getMCInstrInfo())); +} + +// This code is based on llvm-objdump in llvm/tools. +StatusOr Disassembler::DisassembleObjectFile( + const llvm::object::ObjectFile& object_file) const { + if (disassembler_ == nullptr) { + return NotFound("could not find a disassembler for this platform"); + } + + std::string buffer_string; + llvm::raw_string_ostream ostream(buffer_string); + + // Iterate through sections. Disassemble symbols of the text section(s). + for (auto& section : object_file.sections()) { + if (!section.isText()) { + continue; + } + + // Gather symbols from the section. + std::vector symbols; + for (auto& symbol : object_file.symbols()) { + if (section.containsSymbol(symbol)) { + symbols.push_back(symbol); + } + } + + // Sort the symbols in increasing address order. + std::sort( + symbols.begin(), symbols.end(), + [](const llvm::object::SymbolRef& a, const llvm::object::SymbolRef& b) { + // getAddress returns a Expected object. Assert there is no error + // before extracting the address. + llvm::Expected a_address_or_error = a.getAddress(); + CHECK(a_address_or_error); + llvm::Expected b_address_or_error = b.getAddress(); + CHECK(b_address_or_error); + return a_address_or_error.get() < b_address_or_error.get(); + }); + + // Construct ArrayRef pointing to section contents. + llvm::StringRef section_content_string; + if (section.getContents(section_content_string)) { + continue; + } + llvm::ArrayRef section_content_bytes( + reinterpret_cast(section_content_string.data()), + section_content_string.size()); + + // Use int types from LLVM (eg, uint64_t) for values passed to and returned + // from the LLVM API. These values map to different types in LLVM and + // XLA (unsigned long vs unsigned long long). + uint64_t section_address = section.getAddress(); + uint64_t section_size = section.getSize(); + + // Iterate through symbols in increasing address order and disassemble each + // one. 
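// [Editorial note, not part of the original change] The loop below appends text
// to `ostream` that looks roughly like this: the symbol name, then one line per
// instruction carrying its offset (printed via 0x%08lx) and the disassembled
// instruction, with a "[0x...]" target annotation added for branches. Purely
// illustrative:
//
//   my_symbol:
//   0x00000000 pushq %rbp
//   0x00000001 movq %rsp, %rbp
//   0x00000004 jmp 28 [0x00000020]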
+ for (int i = 0; i < symbols.size(); ++i) { + auto symbol = symbols[i]; + llvm::Expected address = symbol.getAddress(); + CHECK(address); + uint64_t start_index = address.get() - section_address; + + // End of symbol is either the end of the section or the start of the next + // symbol. + uint64_t end_index; + if (i < symbols.size() - 1) { + llvm::Expected next_address = symbols[i + 1].getAddress(); + CHECK(next_address); + end_index = std::min(section_size, next_address.get()); + } else { + end_index = section_size; + } + + // Skip zero-length symbols. + if (start_index == end_index) { + continue; + } + + llvm::Expected name_or_error = symbol.getName(); + TF_RET_CHECK(name_or_error); + ostream << name_or_error.get().str() << ":\n"; + + // Disassemble symbol instruction-by-instruction. + uint64_t index = start_index; + while (index < end_index) { + llvm::MCInst instruction; + uint64_t size; + llvm::MCDisassembler::DecodeStatus decode_status = + disassembler_->getInstruction(instruction, size, + section_content_bytes.slice(index), + /*Address=*/section_address + index, + /*VStream=*/llvm::nulls(), + /*CStream=*/llvm::nulls()); + // If we fail to disassemble, then we must skip past this address. + if (size == 0) { + size = 1; + } + + ostream << tensorflow::strings::Printf("0x%08lx", index) << " "; + + if (decode_status == llvm::MCDisassembler::Success) { + // For branches, try to determine the actual address and emit it as an + // annotation. + string annotation; + if (inst_analysis_ && + (inst_analysis_->isUnconditionalBranch(instruction) || + inst_analysis_->isConditionalBranch(instruction))) { + uint64_t target; + if (inst_analysis_->evaluateBranch( + instruction, section_address + index, size, target)) { + annotation = tensorflow::strings::Printf("[0x%08lx]", target); + } + } + inst_printer_->printInst(&instruction, ostream, annotation.c_str(), + subtarget_info_); + } else { + ostream << " "; + } + + ostream << "\n"; + index += size; + } + } + } + + ostream.flush(); + return string(buffer_string.data(), buffer_string.length()); +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/disassembler.h b/tensorflow/compiler/xla/service/cpu/disassembler.h new file mode 100644 index 0000000000..e90f26fc82 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/disassembler.h @@ -0,0 +1,63 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_DISASSEMBLER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_DISASSEMBLER_H_ + +#include +#include + +#include "external/llvm/include/llvm/MC/MCContext.h" +#include "external/llvm/include/llvm/MC/MCDisassembler/MCDisassembler.h" +#include "external/llvm/include/llvm/MC/MCInstPrinter.h" +#include "external/llvm/include/llvm/MC/MCInstrAnalysis.h" +#include "external/llvm/include/llvm/MC/MCObjectFileInfo.h" +#include "external/llvm/include/llvm/MC/MCSubtargetInfo.h" +#include "external/llvm/include/llvm/Object/ObjectFile.h" +#include "external/llvm/include/llvm/Target/TargetMachine.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { +namespace cpu { + +// Class for disassembling object files (and potentially other constructs) into +// X86 assembly. Builds all the LLVM disassembly and instruction printing +// constructs from a given TargetMachine. +class Disassembler { + public: + explicit Disassembler(const llvm::TargetMachine& target_machine); + + // Returns a string containing the disassembled text sections of the given + // object file. + // + // If we couldnt' retrieve a disassembler for this platform, an error status + // is returned. + StatusOr DisassembleObjectFile( + const llvm::object::ObjectFile& object_file) const; + + private: + const llvm::MCSubtargetInfo& subtarget_info_; + std::unique_ptr objfile_info_; + std::unique_ptr mc_context_; + std::unique_ptr disassembler_; + std::unique_ptr inst_printer_; + std::unique_ptr inst_analysis_; +}; + +} // namespace cpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_DISASSEMBLER_H_ diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc new file mode 100644 index 0000000000..420f9cebc5 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -0,0 +1,346 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" + +#include +#include + +#include "external/llvm/include/llvm/IR/BasicBlock.h" +#include "external/llvm/include/llvm/IR/Instructions.h" +#include "external/llvm/include/llvm/IR/Module.h" +#include "external/llvm/include/llvm/IR/Value.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" +#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +using llvm_ir::SetToFirstInsertPoint; + +namespace cpu { + +DotOpEmitter::DotOpEmitter(const HloInstruction& dot, bool transpose_lhs, + bool transpose_rhs, + const llvm_ir::IrArray& target_array, + const llvm_ir::IrArray& lhs_array, + const llvm_ir::IrArray& rhs_array, + llvm::Value* executable_run_options_value, + llvm::IRBuilder<>* ir_builder) + : dot_(dot), + transpose_lhs_(transpose_lhs), + transpose_rhs_(transpose_rhs), + target_array_(target_array), + lhs_array_(lhs_array), + rhs_array_(rhs_array), + executable_run_options_value_(executable_run_options_value), + ir_builder_(ir_builder) {} + +/* static */ tensorflow::Status DotOpEmitter::EmitDotOperation( + const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs, + const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, + const llvm_ir::IrArray& rhs_array, + llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder) { + PrimitiveType type = target_array.GetShape().element_type(); + TF_RET_CHECK(F32 == type || F64 == type); + DotOpEmitter dot_emitter(dot, transpose_lhs, transpose_rhs, target_array, + lhs_array, rhs_array, executable_run_options_value, + ir_builder); + return dot_emitter.Emit(); +} + +bool DotOpEmitter::ShapesAreLegalForRuntimeDot() const { return true; } + +tensorflow::Status DotOpEmitter::Emit() { + // The dot operation performs a sum of products over dimension 0 of the left + // hand side operand and dimension 1 of the right hand side operand. + // + // Let the shapes of lhs and rhs be defined as below: + // + // lhs = [L{n-1} x L{n-2} x ... L{0}] + // rhs = [R{m-1} x R{m-2} x ... R{0}] + // + // The sum-of-products dimension in the lhs has size L{0} and the dimension in + // the rhs has size R{1}. Necessarily, then: + // + // L{0} == R{1} + // + // The output of the operation has the following shape: + // + // output = [L{n-1} x L{n-2} x ... L{1} x R{m-1} x R{m-2} x ... R{2} x R{0}] + // + // To perform the operation we construct a loop nest with one for-loop for + // each dimension of the output. Inside this loop nest is another for-loop + // which performs the sum-of-products (the reduction loop) before storing + // the result in the output buffer. + + const Shape& lhs_shape = lhs_array_.GetShape(); + const Shape& rhs_shape = rhs_array_.GetShape(); + + if (ShapeUtil::IsScalar(lhs_shape) || ShapeUtil::IsScalar(rhs_shape)) { + // If the operands are scalar, don't emit any loops. 
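// [Editorial note, not part of the original change] For example, a dot of two
// f32[] scalars lowers to a single scalar multiply via EmitScalarDot() below;
// no loop nest is emitted in that case.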
+ TF_RET_CHECK(ShapeUtil::IsScalar(lhs_shape) && + ShapeUtil::IsScalar(rhs_shape)); + return EmitScalarDot(); + } + + if (PotentiallyImplementedAsEigenDot(dot_)) { + return EmitCallToRuntime(); + } + + // Reduce along dimension 0 of the LHS and 1 of the RHS. Vectors are a special + // case where the reduction dimension is 0 for both LHS and RHS. This results + // in a vector dot product producing a scalar. + int64 lhs_reduction_dimension = 0; + if (ShapeUtil::Rank(lhs_shape) >= 2) { + lhs_reduction_dimension = + ShapeUtil::GetDimensionNumber(lhs_shape, transpose_lhs_ ? -2 : -1); + } + int64 rhs_reduction_dimension = 0; + if (ShapeUtil::Rank(rhs_shape) >= 2) { + rhs_reduction_dimension = + ShapeUtil::GetDimensionNumber(rhs_shape, transpose_rhs_ ? -1 : -2); + } + + // Verify the reduction dimension in the two operands are the same size. + TF_RET_CHECK(lhs_shape.dimensions(lhs_reduction_dimension) == + rhs_shape.dimensions(rhs_reduction_dimension)); + + // Create loop nests which loop through the LHS operand dimensions and the RHS + // operand dimensions. The reduction dimension of the LHS and RHS are handled + // in a separate innermost loop which performs the sum of products. + llvm_ir::ForLoopNest loop_nest(ir_builder_); + llvm_ir::IrArray::Index lhs_index = EmitOperandArrayLoopNest( + &loop_nest, lhs_array_, lhs_reduction_dimension, "lhs"); + llvm_ir::IrArray::Index rhs_index = EmitOperandArrayLoopNest( + &loop_nest, rhs_array_, rhs_reduction_dimension, "rhs"); + + // Create the loop which does the sum of products reduction. + std::unique_ptr reduction_loop = loop_nest.AddLoop( + 0, lhs_shape.dimensions(lhs_reduction_dimension), "reduction"); + + // The final entry in the rhs and lhs indexes is the indvar of the + // reduction loop. + lhs_index[lhs_reduction_dimension] = reduction_loop->GetIndVarValue(); + rhs_index[rhs_reduction_dimension] = reduction_loop->GetIndVarValue(); + + // For computing the sum of products we alloca a single location to store the + // dot product result as we accumulate it within the reduction loop. After the + // reduction loop we load the result and store into the output array. + + // Function entry basic block. + // - Emit alloca for accumulator + llvm::Function* func = reduction_loop->GetPreheaderBasicBlock()->getParent(); + SetToFirstInsertPoint(&func->getEntryBlock(), ir_builder_); + llvm::Type* accum_type = target_array_.GetElementLlvmType(); + llvm::Value* accum_address = ir_builder_->CreateAlloca( + accum_type, /*ArraySize=*/nullptr, "accum_address"); + + // Preheader basic block of reduction loop: + // - Initialize accumulator to zero. + llvm::BasicBlock* preheader_bb = reduction_loop->GetPreheaderBasicBlock(); + ir_builder_->SetInsertPoint(preheader_bb->getTerminator()); + + ir_builder_->CreateStore(llvm::ConstantFP::get(accum_type, 0.0), + accum_address); + + // Body basic block of reduction loop: + // - Load elements from lhs and rhs array. + // - Multiply lhs-element and rhs-element. + // - Load accumulator and add to product. + // - Store sum back into accumulator. 
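// [Editorial sketch, not part of the original change] For a plain 2-D dot with
// lhs of shape [M x K] and rhs of shape [K x N] (no transposes), the loop nest
// built by this method corresponds to the following C-like pseudocode:
//
//   for (int64 i = 0; i < M; ++i) {
//     for (int64 j = 0; j < N; ++j) {
//       float accum = 0;
//       for (int64 k = 0; k < K; ++k) {
//         accum += lhs[i][k] * rhs[k][j];
//       }
//       out[i][j] = accum;
//     }
//   }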
+ SetToFirstInsertPoint(reduction_loop->GetBodyBasicBlock(), ir_builder_); + + llvm::Value* lhs_element = + lhs_array_.EmitReadArrayElement(lhs_index, ir_builder_); + llvm::Value* rhs_element = + rhs_array_.EmitReadArrayElement(rhs_index, ir_builder_); + + llvm::Value* product = ir_builder_->CreateFMul(lhs_element, rhs_element); + llvm::Value* accum = ir_builder_->CreateLoad(accum_address); + llvm::Value* updated_accum = ir_builder_->CreateFAdd(accum, product); + ir_builder_->CreateStore(updated_accum, accum_address); + + // Exit basic block of reduction loop. + // - Load accumulator value (the result). + // - Store into output array. + SetToFirstInsertPoint(reduction_loop->GetExitBasicBlock(), ir_builder_); + + llvm::Value* result = ir_builder_->CreateLoad(accum_address); + + // Create index into target address. The target index is the concatenation of + // the rhs and lhs indexes with the reduction dimensions removed. The terms + // from the rhs index are the lower dimensions in the index so we add them + // first. + llvm_ir::IrArray::Index target_index; + for (int dimension = 0; dimension < lhs_index.size(); ++dimension) { + if (dimension != lhs_reduction_dimension) { + target_index.push_back(lhs_index[dimension]); + } + } + for (int dimension = 0; dimension < rhs_index.size(); ++dimension) { + if (dimension != rhs_reduction_dimension) { + target_index.push_back(rhs_index[dimension]); + } + } + + target_array_.EmitWriteArrayElement(target_index, result, ir_builder_); + + // Set the IR builder insert point to the exit basic block of the outer most + // loop. + ir_builder_->SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock()); + + return tensorflow::Status::OK(); +} + +tensorflow::Status DotOpEmitter::EmitScalarDot() { + // A scalar dot is just a scalar multiply. + llvm::Value* lhs_value = + lhs_array_.EmitReadArrayElement(/*index=*/{}, ir_builder_); + llvm::Value* rhs_value = + rhs_array_.EmitReadArrayElement(/*index=*/{}, ir_builder_); + llvm::Value* result = ir_builder_->CreateFMul(lhs_value, rhs_value); + target_array_.EmitWriteArrayElement(/*index=*/{}, result, ir_builder_); + return tensorflow::Status::OK(); +} + +tensorflow::Status DotOpEmitter::EmitCallToRuntime() { + DCHECK(ShapesAreLegalForRuntimeDot()); + + // The signature of the Eigen runtime matmul function is: + // + // (void)(void* run_options, float* out, float* lhs, float* rhs, + // int64 m, int64 n, int64 k, int32 transpose_lhs, + // int32 transpose_rhs); + // The two transpose_... parameters are actually booleans, but we use int32 + // to avoid target-dependent calling convention details. + + legacy_flags::CpuRuntimeFlags* flags = legacy_flags::GetCpuRuntimeFlags(); + bool multi_threaded = flags->xla_cpu_multi_thread_eigen; + PrimitiveType type = target_array_.GetShape().element_type(); + llvm::Type* float_type; + const char* fn_name; + switch (type) { + case F32: + fn_name = multi_threaded + ? runtime::kEigenMatmulF32SymbolName + : runtime::kEigenSingleThreadedMatmulF32SymbolName; + float_type = ir_builder_->getFloatTy(); + break; + case F64: + fn_name = multi_threaded + ? 
runtime::kEigenMatmulF64SymbolName + : runtime::kEigenSingleThreadedMatmulF64SymbolName; + float_type = ir_builder_->getDoubleTy(); + break; + default: + return Unimplemented("Invalid type %s for dot operation", + PrimitiveType_Name(type).c_str()); + } + + llvm::Type* float_ptr_type = float_type->getPointerTo(); + llvm::Type* int64_type = ir_builder_->getInt64Ty(); + llvm::Type* int32_type = ir_builder_->getInt32Ty(); + llvm::Type* int8_ptr_type = ir_builder_->getInt8Ty()->getPointerTo(); + llvm::FunctionType* matmul_type = llvm::FunctionType::get( + ir_builder_->getVoidTy(), + {int8_ptr_type, float_ptr_type, float_ptr_type, float_ptr_type, + int64_type, int64_type, int64_type, int32_type, int32_type}, + /*isVarArg=*/false); + + llvm::Function* function = ir_builder_->GetInsertBlock()->getParent(); + llvm::Module* module = function->getParent(); + + llvm::Function* matmul_func = llvm::cast( + module->getOrInsertFunction(fn_name, matmul_type)); + matmul_func->setCallingConv(llvm::CallingConv::C); + matmul_func->setDoesNotThrow(); + matmul_func->setOnlyAccessesArgMemory(); + + // The Eigen runtime function expects column-major layout. If the matrices are + // row major, then use the following identity to compute the product: + // + // (A x B)^T = B^T x A^T + // + // Effectively this involves swapping the 'lhs' with 'rhs' and 'm' with 'n'. + + const Shape& lhs_shape = lhs_array_.GetShape(); + const Shape& rhs_shape = rhs_array_.GetShape(); + int64 m = lhs_shape.dimensions(transpose_lhs_ ? 1 : 0); + int64 k = lhs_shape.dimensions(transpose_lhs_ ? 0 : 1); + int64 n = rhs_shape.dimensions(transpose_rhs_ ? 0 : 1); + const llvm_ir::IrArray* lhs = &lhs_array_; + const llvm_ir::IrArray* rhs = &rhs_array_; + bool transpose_lhs = transpose_lhs_; + bool transpose_rhs = transpose_rhs_; + + bool is_column_major = lhs_shape.layout().minor_to_major(0) == 0; + if (!is_column_major) { + std::swap(m, n); + std::swap(lhs, rhs); + std::swap(transpose_lhs, transpose_rhs); + } + + ir_builder_->CreateCall( + matmul_func, + {ir_builder_->CreateBitCast(executable_run_options_value_, int8_ptr_type), + ir_builder_->CreateBitCast(target_array_.GetBasePointer(), + float_ptr_type), + ir_builder_->CreateBitCast(lhs->GetBasePointer(), float_ptr_type), + ir_builder_->CreateBitCast(rhs->GetBasePointer(), float_ptr_type), + ir_builder_->getInt64(m), ir_builder_->getInt64(n), + ir_builder_->getInt64(k), ir_builder_->getInt32(transpose_lhs), + ir_builder_->getInt32(transpose_rhs)}); + return tensorflow::Status::OK(); +} + +llvm_ir::IrArray::Index DotOpEmitter::EmitOperandArrayLoopNest( + llvm_ir::ForLoopNest* loop_nest, const llvm_ir::IrArray& operand_array, + int64 reduction_dimension, tensorflow::StringPiece name_suffix) { + // Prepares the dimension list we will use to emit the loop nest. Outermost + // loops are added first. Add loops in major-to-minor order, and skip the + // reduction dimension. + std::vector dimensions; + const Shape& shape = operand_array.GetShape(); + for (int i = shape.layout().minor_to_major_size() - 1; i >= 0; --i) { + int64 dimension = shape.layout().minor_to_major(i); + if (dimension != reduction_dimension) { + dimensions.push_back(dimension); + } + } + + // Create loop nest with one for-loop for each dimension of the + // output. + llvm_ir::IrArray::Index index = + loop_nest->AddLoopsForShapeOnDimensions(shape, dimensions, name_suffix); + // Verify every dimension except the reduction dimension was set in the index. 
+ for (int dimension = 0; dimension < index.size(); ++dimension) { + if (dimension == reduction_dimension) { + DCHECK_EQ(nullptr, index[dimension]); + } else { + DCHECK_NE(nullptr, index[dimension]); + } + } + return index; +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h new file mode 100644 index 0000000000..44dfe5f2a9 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h @@ -0,0 +1,90 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_DOT_OP_EMITTER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_DOT_OP_EMITTER_H_ + +#include "external/llvm/include/llvm/IR/IRBuilder.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace cpu { + +// Helper class for emitting LLVM IR to perform the dot operation. +class DotOpEmitter { + public: + // Emit LLVM IR to perform the dot operation on lhs_array and rhs_array and + // place the result in target_array. IR is emitted at current insert point of + // the builder. Upon completion of the method, the insert point is set to the + // end of all instructions emitted for this operation. + static tensorflow::Status EmitDotOperation( + const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs, + const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, + const llvm_ir::IrArray& rhs_array, + llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder); + + private: + DotOpEmitter(const HloInstruction& dot, bool transpose_lhs, + bool transpose_rhs, const llvm_ir::IrArray& target_array, + const llvm_ir::IrArray& lhs_array, + const llvm_ir::IrArray& rhs_array, + llvm::Value* executable_run_options_value, + llvm::IRBuilder<>* ir_builder); + + // Emits the IR to perform the dot operation. + tensorflow::Status Emit(); + + // Emits instructions to perform a scalar dot product (a multiply of the + // LHS and RHS) and store the results in the target. + tensorflow::Status EmitScalarDot(); + + // Emits a call to the CPU runtime to perform the matrix multiply. + tensorflow::Status EmitCallToRuntime(); + + // Emits a series of nested loops for iterating over an operand array in the + // dot operation. Loops are constructed in major to minor dimension layout + // order. No loop is emitted for the given reduction_dimension. The function + // returns an IrArray index for the given operand_array containing the indvars + // of the loops. 
All dimensions of the index are filled except for the + // reduction dimension. name_suffix is the string to append to the names of + // LLVM constructs (eg, basic blocks) constructed by this method. + llvm_ir::IrArray::Index EmitOperandArrayLoopNest( + llvm_ir::ForLoopNest* loop_nest, const llvm_ir::IrArray& operand_array, + int64 reduction_dimension, tensorflow::StringPiece name_suffix); + + // Our runtime operation requires that all arrays have the same layout, + // no padding, and a rank of two. + bool ShapesAreLegalForRuntimeDot() const; + + const HloInstruction& dot_; + const bool transpose_lhs_; + const bool transpose_rhs_; + const llvm_ir::IrArray& target_array_; + const llvm_ir::IrArray& lhs_array_; + const llvm_ir::IrArray& rhs_array_; + llvm::Value* executable_run_options_value_; + llvm::IRBuilder<>* ir_builder_; +}; + +} // namespace cpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_DOT_OP_EMITTER_H_ diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc new file mode 100644 index 0000000000..9b46c35b41 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc @@ -0,0 +1,68 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h" + +#include + +#include "external/llvm/include/llvm/IR/Instructions.h" +#include "external/llvm/include/llvm/IR/Module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace cpu { + +StatusOr CpuElementalIrEmitter::EmitFloatUnaryOp( + const HloInstruction* op, llvm::Value* operand_value) const { + switch (op->opcode()) { + case HloOpcode::kTanh: { + PrimitiveType element_type = op->shape().element_type(); + string function_name; + switch (element_type) { + case F32: + function_name = "tanhf"; + break; + case F64: + function_name = "tanh"; + break; + default: + return Unimplemented("tanh"); + } + // Create function type for the function. + llvm::FunctionType* function_type = llvm::FunctionType::get( + llvm_ir::PrimitiveTypeToIrType(element_type, ir_builder_), + llvm_ir::PrimitiveTypeToIrType(element_type, ir_builder_), + /*isVarArg=*/false); + // Create function declaration for 'tanhf'. + llvm::Function* function = + llvm::cast(module_->getOrInsertFunction( + llvm_ir::AsStringRef(function_name), function_type)); + function->setCallingConv(llvm::CallingConv::C); + function->setDoesNotThrow(); + function->setDoesNotAccessMemory(); + // Create instruction to call 'tanhf'. 
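// [Editorial note, not part of the original change] For an F32 operand the
// call emitted below is roughly equivalent to the following LLVM IR (names are
// illustrative only):
//
//   declare float @tanhf(float)
//   %tanh = call float @tanhf(float %operand_value)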
+ return ir_builder_->CreateCall(function, operand_value); + } + default: + return ElementalIrEmitter::EmitFloatUnaryOp(op, operand_value); + } +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h new file mode 100644 index 0000000000..5160217674 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h @@ -0,0 +1,43 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_ELEMENTAL_IR_EMITTER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_ELEMENTAL_IR_EMITTER_H_ + +#include "external/llvm/include/llvm/IR/IRBuilder.h" +#include "external/llvm/include/llvm/IR/Module.h" +#include "external/llvm/include/llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { +namespace cpu { + +class CpuElementalIrEmitter : public ElementalIrEmitter { + public: + CpuElementalIrEmitter(const HloModuleConfig& module_config, + llvm::IRBuilder<>* ir_builder, llvm::Module* module) + : ElementalIrEmitter(module_config, module, ir_builder) {} + + protected: + StatusOr EmitFloatUnaryOp( + const HloInstruction* op, llvm::Value* operand_value) const override; +}; + +} // namespace cpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_ELEMENTAL_IR_EMITTER_H_ diff --git a/tensorflow/compiler/xla/service/cpu/infeed_manager.cc b/tensorflow/compiler/xla/service/cpu/infeed_manager.cc new file mode 100644 index 0000000000..23a2dfcc32 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/infeed_manager.cc @@ -0,0 +1,72 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/infeed_manager.h" + +#include "tensorflow/core/platform/logging.h" + +namespace xla { +namespace cpu { +namespace runtime { + +InfeedBuffer::~InfeedBuffer() = default; + +InfeedManager::InfeedManager() : current_buffer_(nullptr) {} + +void InfeedManager::Reset() { + std::unique_lock l(mu_); + CHECK(!current_buffer_); + for (auto buffer : enqueued_buffer_) { + buffer->Done(); + } + enqueued_buffer_.clear(); +} + +void InfeedManager::EnqueueBuffer(InfeedBuffer* buffer) { + std::unique_lock l(mu_); + bool was_empty = enqueued_buffer_.empty(); + enqueued_buffer_.push_back(buffer); + if (was_empty) { + // This has the potential to suffer from the notified thread + // immediately trying and failing to acquire mu_, but seems + // preferable to the alternative of notifying outside the lock + // on every enqueue. + cv_.notify_one(); + } +} + +InfeedBuffer* InfeedManager::BlockingDequeueBuffer() { + std::unique_lock l(mu_); + while (enqueued_buffer_.empty()) { + cv_.wait(l); + } + CHECK(!current_buffer_); + current_buffer_ = enqueued_buffer_.front(); + enqueued_buffer_.pop_front(); + return current_buffer_; +} + +void InfeedManager::ReleaseCurrentBuffer(int32 length, void* data) { + std::unique_lock l(mu_); + CHECK(current_buffer_); + CHECK_EQ(length, current_buffer_->length()); + CHECK_EQ(data, current_buffer_->data()); + current_buffer_->Done(); + current_buffer_ = nullptr; +} + +} // namespace runtime +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/infeed_manager.h b/tensorflow/compiler/xla/service/cpu/infeed_manager.h new file mode 100644 index 0000000000..298729f31f --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/infeed_manager.h @@ -0,0 +1,95 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This header declares the abstract class for the infeed manager that +// is used by the CPU runtime to transfer buffers into an executing +// CPU computation, e.g., to feed data into a while loop. + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_INFEED_MANAGER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_INFEED_MANAGER_H_ + +// TODO(misard) Adding NOLINT because as soon as XLA is +// open-sourced this will use the tensorflow wrapper classes. +#include // NOLINT(build/c++11) +#include +#include // NOLINT(build/c++11) + +#include "tensorflow/compiler/xla/types.h" + +namespace xla { +namespace cpu { +namespace runtime { + +// Abstract class defining an infeed buffer that is passed to the +// runtime by a client. The client manages the storage of the buffer. +class InfeedBuffer { + public: + virtual ~InfeedBuffer(); + + virtual int32 length() = 0; + virtual void* data() = 0; + virtual void Done() = 0; +}; + +// Client-side class used to enqueue infeed buffers. 
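// [Editorial usage sketch, not part of the original change] Typical flow,
// assuming a client-provided InfeedBuffer subclass named MyInfeedBuffer
// (hypothetical; see TestInfeedBuffer in infeed_manager_test.cc for a concrete
// implementation):
//
//   // Client thread: hand a buffer to the runtime.
//   xla::cpu::runtime::InfeedManager* manager =
//       xla::cpu::runtime::GetInfeedManager();
//   manager->EnqueueBuffer(new MyInfeedBuffer(/*length=*/64));
//
//   // Generated code, via the cpu_runtime.h entry points: consume the buffer.
//   void* data = __xla_cpu_runtime_AcquireInfeedBufferForDequeue(64);
//   // ... read 64 bytes from `data` ...
//   __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(64, data);
//
// MyInfeedBuffer::Done() is invoked by the manager once the buffer is released
// (or on Reset()), so the client can reclaim its storage there.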
+class InfeedManager { + public: + InfeedManager(); + + // Calls the completion callback for any enqueued buffers that have + // not been dequeued by the runtime, and empties the infeed + // queue. Reset may not be called while a runtime computation is + // processing a dequeued buffer. The only safe way to ensure this + // condition is to call Reset when no computation is taking place. + void Reset(); + + // Adds buffer to the infeed queue. buffer->Done will be called when + // the buffer will no longer be accessed by the InfeedManager, + // either as a result of a call to Reset or because the runtime has + // dequeued and used the buffer. + void EnqueueBuffer(InfeedBuffer* buffer); + + // Blocks until the infeed queue is non-empty, then returns the + // buffer at the head of the queue. Sets the current buffer to be + // the returned buffer. It is an error to call BlockingDequeueBuffer + // if there is an unreleased current buffer, i.e., + // ReleaseCurrentBuffer must be called between calls to + // BlockingDequeueBuffer. + InfeedBuffer* BlockingDequeueBuffer(); + + // Releases the current buffer, which is the last buffer returned by + // BlockingDequeuBuffer and not yet released. length and data must + // match the buffer->length() and buffer->data() for the current + // buffer. + void ReleaseCurrentBuffer(int32 length, void* data); + + private: + std::mutex mu_; + // Condition variable that is signaled every time a buffer is + // enqueued to an empty queue. + std::condition_variable cv_; + // InfeedBuffer* queue contents are not owned, but buffer->Done must + // be called when the buffer is no longer needed by the runtime. + std::deque enqueued_buffer_; + // If non-NULL, the buffer that is currently being processed by the + // runtime. Not owned. + InfeedBuffer* current_buffer_; +}; + +} // namespace runtime +} // namespace cpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_INFEED_MANAGER_H_ diff --git a/tensorflow/compiler/xla/service/cpu/infeed_manager_test.cc b/tensorflow/compiler/xla/service/cpu/infeed_manager_test.cc new file mode 100644 index 0000000000..c65d821660 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/infeed_manager_test.cc @@ -0,0 +1,102 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/infeed_manager.h" + +#include + +#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +class InfeedManagerTest : public ::testing::Test {}; + +class TestInfeedBuffer : public cpu::runtime::InfeedBuffer { + public: + explicit TestInfeedBuffer(int32 length) + : done_called_(false), length_(length) {} + ~TestInfeedBuffer() override { EXPECT_TRUE(done_called_); } + + int32 length() override { return length_; } + void* data() override { return nullptr; } + void Done() override { + CHECK(!done_called_); + done_called_ = true; + } + + private: + bool done_called_; + int32 length_; +}; + +void ProcessNextBuffer(int32 length) { + void* buffer = __xla_cpu_runtime_AcquireInfeedBufferForDequeue(length); + __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(length, buffer); +} + +TEST_F(InfeedManagerTest, SingleThreadedSequential) { + TestInfeedBuffer* a = new TestInfeedBuffer(64); + TestInfeedBuffer* b = new TestInfeedBuffer(32); + + cpu::runtime::InfeedManager* infeed = cpu::runtime::GetInfeedManager(); + + infeed->EnqueueBuffer(a); + infeed->EnqueueBuffer(b); + ProcessNextBuffer(a->length()); + ProcessNextBuffer(b->length()); +} + +TEST_F(InfeedManagerTest, SingleThreadedInterleaved) { + TestInfeedBuffer* a = new TestInfeedBuffer(64); + TestInfeedBuffer* b = new TestInfeedBuffer(32); + + cpu::runtime::InfeedManager* infeed = cpu::runtime::GetInfeedManager(); + + infeed->EnqueueBuffer(a); + ProcessNextBuffer(a->length()); + infeed->EnqueueBuffer(b); + ProcessNextBuffer(b->length()); +} + +TEST_F(InfeedManagerTest, MultiThreaded) { + tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "test", 2); + + cpu::runtime::InfeedManager* infeed = cpu::runtime::GetInfeedManager(); + + const int32 length = 64; + + pool.Schedule([infeed]() { + // Spin for 100 milliseconds + int64 start_micros = tensorflow::Env::Default()->NowMicros(); + while (true) { + int64 end_micros = tensorflow::Env::Default()->NowMicros(); + if ((end_micros - start_micros) >= 100000) { // 100 ms + break; + } + } + TestInfeedBuffer* a = new TestInfeedBuffer(length); + infeed->EnqueueBuffer(a); + }); + + ProcessNextBuffer(length); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc new file mode 100644 index 0000000000..2d855d0eb1 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc @@ -0,0 +1,127 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/window_util.h" + +namespace xla { +namespace cpu { + +bool PotentiallyImplementedAsEigenConvolution( + const HloInstruction& convolution) { + legacy_flags::CpuRuntimeFlags* flags = legacy_flags::GetCpuRuntimeFlags(); + if (!flags->xla_cpu_use_eigen) { + return false; + } + + // The following conditions are necessary (but not sufficient) for + // implementing `convolution` with Eigen convolution: + // - the input and kernel have a non-zero number of elements. + // - the input is in NHWC or NWHC order. + // - the kernel is in HWIO or WHIO order. + // - the spatial dimensions are in the same relative order in the input, + // kernel and output. + // + // To be sufficient, certain layout constraints need to be satisfied as well. + if (ShapeUtil::HasZeroElements(convolution.operand(0)->shape()) || + ShapeUtil::HasZeroElements(convolution.operand(1)->shape())) { + return false; + } + const ConvolutionDimensionNumbers& dnums = + convolution.convolution_dimension_numbers(); + // Only 2D convolutions are supported at the moment. + // TODO(b/32897908): add an optimized implementation for 3D convolution. + if (dnums.spatial_dimensions_size() != 2) { + return false; + } + bool input_spatial_dims_ascending = + dnums.spatial_dimensions(0) < dnums.spatial_dimensions(1); + bool kernel_spatial_dims_ascending = + dnums.kernel_spatial_dimensions(0) < dnums.kernel_spatial_dimensions(1); + return dnums.batch_dimension() == 0 && dnums.feature_dimension() == 3 && + input_spatial_dims_ascending == kernel_spatial_dims_ascending && + dnums.kernel_input_feature_dimension() == 2 && + dnums.kernel_output_feature_dimension() == 3; +} + +namespace { + +// Return whether the given shape is a matrix with no padding. +bool IsRank2WithNoPadding(const Shape& shape) { + return ShapeUtil::Rank(shape) == 2 && !LayoutUtil::IsPadded(shape); +} + +// In a gemm operation where output = lhs * rhs, check whether the given shapes +// are valid for the operation. +bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, + const Shape& output_shape) { + // The inputs and the output must + // 1) be matrices with no padding, and + // 2) have an allowed element type. + return output_shape.element_type() == F32 && + IsRank2WithNoPadding(lhs_shape) && IsRank2WithNoPadding(rhs_shape) && + IsRank2WithNoPadding(output_shape); +} +} // namespace + +bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) { + legacy_flags::CpuRuntimeFlags* flags = legacy_flags::GetCpuRuntimeFlags(); + if (!flags->xla_cpu_use_eigen) { + return false; + } + + // For certain types of Dot, we can call Eigen + if (hlo.opcode() == HloOpcode::kDot) { + const Shape& lhs_shape = hlo.operand(0)->shape(); + const Shape& rhs_shape = hlo.operand(1)->shape(); + + if (ShapeUtil::HasZeroElements(lhs_shape) || + ShapeUtil::HasZeroElements(rhs_shape)) { + return false; + } + + // If gemm can accept the operand shapes, use it rather than a custom + // kernel. + if (AreValidGemmShapes(lhs_shape, rhs_shape, hlo.shape())) { + // The size of the reduction dimension should match. 
The shape inference + // guarantees this invariant, so the check here is for programming + // errors. + CHECK_EQ(lhs_shape.dimensions(1), rhs_shape.dimensions(0)); + return true; + } + } + + if (hlo.opcode() == HloOpcode::kFusion && + hlo.fusion_kind() == HloInstruction::FusionKind::kTransposeDot && + hlo.fused_expression_root()->opcode() == HloOpcode::kDot) { + const Shape& lhs_shape = hlo.operand(0)->shape(); + const Shape& rhs_shape = hlo.operand(1)->shape(); + if (ShapeUtil::HasZeroElements(lhs_shape) || + ShapeUtil::HasZeroElements(rhs_shape)) { + return false; + } + return true; + } + + return false; +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h new file mode 100644 index 0000000000..d48646d116 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h @@ -0,0 +1,32 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMISSION_UTILS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMISSION_UTILS_H_ + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" + +namespace xla { +namespace cpu { + +bool PotentiallyImplementedAsEigenConvolution( + const HloInstruction& convolution); + +bool PotentiallyImplementedAsEigenDot(const HloInstruction& dot); + +} // namespace cpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMISSION_UTILS_H_ diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc new file mode 100644 index 0000000000..7a839169fc --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -0,0 +1,1774 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/ir_emitter.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/core/platform/logging.h" +// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" +#include "external/llvm/include/llvm/IR/BasicBlock.h" +#include "external/llvm/include/llvm/IR/Constants.h" +#include "external/llvm/include/llvm/IR/GlobalVariable.h" +#include "external/llvm/include/llvm/IR/Instructions.h" +#include "external/llvm/include/llvm/IR/Intrinsics.h" +#include "external/llvm/include/llvm/IR/LLVMContext.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h" +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" +#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" +#include "tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h" +#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" +#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ops.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/window_util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace xla { + +using llvm_ir::SetToFirstInsertPoint; + +namespace cpu { + +IrEmitter::IrEmitter( + const HloModule& hlo_module, const HloModuleConfig& hlo_module_config, + const BufferAssignment& assignment, llvm::Module* llvm_module, + const std::unordered_map* hlo_to_profile_idx) + : assignment_(assignment), + module_(llvm_module), + arch_type_(llvm::Triple(llvm_module->getTargetTriple()).getArch()), + ir_builder_(llvm_module->getContext()), + hlo_to_profile_idx_(hlo_to_profile_idx), + alias_analysis_(hlo_module, assignment, &llvm_module->getContext()), + hlo_module_config_(hlo_module_config) { + llvm::FastMathFlags fast_math_flags; + llvm_ir::SetFastMathFlags(&fast_math_flags); + ir_builder_.setFastMathFlags(fast_math_flags); +} + +StatusOr IrEmitter::EmitComputation( + HloComputation* computation, const string& function_name_prefix, + bool is_entry_computation, + std::vector* instruction_order) { + string function_name = name_uniquer_.GetUniqueName(function_name_prefix); + VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix << "]"; + InitializeIrFunction(function_name, is_entry_computation); + // The rdtscp instruction is x86 specific. We will fallback to LLVM's generic + // readcyclecounter if it is unavailable. 
+ bool use_rdtscp = arch_type_ == llvm::Triple::ArchType::x86 || + arch_type_ == llvm::Triple::ArchType::x86_64; + profiling_state_ = ProfilingState(is_entry_computation, use_rdtscp, + GetProfileCountersArgument()); + if (instruction_order != nullptr) { + TF_RETURN_IF_ERROR(computation->root_instruction()->AcceptOrdered( + this, *instruction_order)); + } else { + TF_RETURN_IF_ERROR(computation->root_instruction()->Accept(this)); + } + InsertOrDie(&emitted_functions_, computation, compute_function_); + + return compute_function_; +} + +static llvm::Argument* GetArg(llvm::Function* f, int idx) { + llvm::Function::arg_iterator arg_iter = f->arg_begin(); + std::advance(arg_iter, idx); + return &*arg_iter; +} + +void IrEmitter::InitializeIrFunction(const string& function_name, + bool is_entry_computation) { + // The function signature is: + // void function(i8* retval, i8* run_options, i8** params, i8** temps, + // i64* prof_counters) + // + // retval: points to the returned value. + // params: address of an array with pointers to parameters. + // temps: address of an array with pointers to temporary buffers. + // + // Therefore, the generated function's signature (FunctionType) is statically + // determined - parameter unpacking is done in code generated into the + // function, rather than by a prologue dictated by the platform ABI. + // + // /--------------\ + // retval ----------> | return value | + // \--------------/ + // + // /-------------------------------\ + // run_options -----> | xla::ExecutableRunOptions | + // \-------------------------------/ + // + // /---------------------------------------------\ + // params --------> | param 0 | param 1 | ..... | param N-1 | + // | addr | addr | | addr | + // \---------------------------------------------/ + // | | | + // | | | + // V V V + // /---------\ /---------\ /-----------\ + // | param 0 | | param 1 | | param N-1 | + // \---------/ \---------/ \-----------/ + // + // /---------------------------------------------\ + // temps ---------> | temp 0 | temp 1 | ..... | temp N-1 | + // | addr | addr | | addr | + // \---------------------------------------------/ + // | | | + // | | | + // V V V + // /---------\ /---------\ /-----------\ + // | temp 0 | | temp 1 | | temp N-1 | + // \---------/ \---------/ \-----------/ + // + // /---------------------------------------------\ + // prof counters -> | counter 0 | counter 1 | ..... | counter N-1 | + // (elided for aot) \---------------------------------------------/ + + // Even though the type of params and temps is void** in the host's view, in + // LLVM IR this is represented by i8*, similarly to void*. It's up to the code + // to use GEPs to unravel the indirection layers. + llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext()); + llvm::Type* i8_ptr_ptr_type = i8_ptr_type->getPointerTo(); + llvm::Type* i64_ptr_type = llvm::Type::getInt64PtrTy(module_->getContext()); + std::vector compute_function_params( + {i8_ptr_type, i8_ptr_type, i8_ptr_ptr_type, i8_ptr_ptr_type}); + if (hlo_to_profile_idx_) { + compute_function_params.push_back(i64_ptr_type); + } + llvm::FunctionType* compute_function_type = llvm::FunctionType::get( + /*Result=*/llvm::Type::getVoidTy(module_->getContext()), + /*Params=*/compute_function_params, + /*isVarArg=*/false); + + // Functions with local linkage get an inlining bonus. Because we know + // a-priori that embedded functions (non-entry functions) will not have its + // name resolved, give it local linkage. 
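// The calling convention documented in the comment block above can be viewed
// from the host side roughly as follows. This is an illustrative sketch only:
// `entry_fn` stands in for a function pointer obtained from the JIT, and the
// parameter, temp and result shapes are invented for the example.

#include <cstdint>

// Mirrors the documented signature:
//   void function(i8* retval, i8* run_options, i8** params, i8** temps,
//                 i64* prof_counters)
using ComputeFn = void (*)(char* retval, char* run_options, char** params,
                           char** temps, int64_t* prof_counters);

void CallCompiledComputation(ComputeFn entry_fn, char* run_options) {
  float param0[4] = {1, 2, 3, 4};  // a single f32[4] parameter
  alignas(16) char temp0[64];      // scratch buffer, sized by buffer assignment
  float result[4];                 // f32[4] return value

  char* params[] = {reinterpret_cast<char*>(param0)};  // parameter addresses
  char* temps[] = {temp0};                             // temp buffer addresses
  int64_t prof_counters[1] = {0};  // elided when compiling ahead-of-time

  entry_fn(reinterpret_cast<char*>(result), run_options, params, temps,
           prof_counters);
}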
+ llvm::Function::LinkageTypes linkage = + is_entry_computation ? llvm::GlobalValue::ExternalLinkage + : llvm::GlobalValue::InternalLinkage; + compute_function_ = llvm::Function::Create(/*Ty=*/compute_function_type, + /*Linkage=*/linkage, + /*Name=*/function_name.c_str(), + /*Module=*/module_); + compute_function_->setCallingConv(llvm::CallingConv::C); + + // Set meaningful names for the function's arguments: useful for debugging. + llvm::Function::arg_iterator arg_iter = compute_function_->arg_begin(); + arg_iter->setName("retval"); + (++arg_iter)->setName("run_options"); + (++arg_iter)->setName("params"); + (++arg_iter)->setName("temps"); + if (hlo_to_profile_idx_) { + (++arg_iter)->setName("prof_counters"); + } + + // We know a-priori that the function arguments are guaranteed to point to + // disjoint objects. + llvm::Argument* retval = GetResultArgument(); + for (llvm::Argument& argument : compute_function_->args()) { + // However, the return buffer aliases the temporaries and thus cannot be + // marked noalias. + if (&argument == retval) { + continue; + } + compute_function_->setDoesNotAlias(argument.getArgNo() + 1); + } + + ir_builder_.SetInsertPoint(llvm::BasicBlock::Create( + /*Context=*/module_->getContext(), + /*Name=*/"entry", + /*Parent=*/compute_function_)); +} + +IrEmitter::~IrEmitter() {} + +Status IrEmitter::HandleBitcast(HloInstruction* bitcast) { + VLOG(2) << "HandleBitcast: " << bitcast->ToString(); + emitted_value_[bitcast] = ir_builder_.CreateBitCast( + GetEmittedValueFor(bitcast->operand(0)), + IrShapeType(bitcast->shape())->getPointerTo(), bitcast->name().c_str()); + return Status::OK(); +} + +Status IrEmitter::HandleConstant(HloInstruction* constant, + const Literal& literal) { + VLOG(2) << "HandleConstant: " << constant->ToString(); + llvm::Constant* initializer = + llvm_ir::ConvertLiteralToIrConstant(literal, &ir_builder_); + llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable( + /*Module=*/*module_, + /*Type=*/initializer->getType(), + /*isConstant=*/true, + /*Linkage=*/llvm::GlobalValue::PrivateLinkage, + /*Initializer=*/initializer, + /*Name=*/""); + emitted_value_[constant] = global_for_const; + VLOG(2) << " emitted value: " << llvm_ir::DumpToString(*global_for_const); + VLOG(2) << " its type: " + << llvm_ir::DumpToString(*global_for_const->getType()); + return Status::OK(); +} + +Status IrEmitter::HandleCopy(HloInstruction* copy, HloInstruction* operand) { + if (ShapeUtil::IsTuple(copy->shape())) { + // kCopy shallow copies a tuple so just memcpy the top-level buffer. + TF_ASSIGN_OR_RETURN(llvm::Value * copy_value, EmitTargetAddressForOp(copy)); + emitted_value_[copy] = copy_value; + return EmitMemcpy(*operand, *copy); + } else { + // Use the elemental emitter for non-tuple shapes. + return DefaultAction(copy); + } +} + +// Calculate the alignment of a buffer with a particular size. +int IrEmitter::MinimumAlignmentForBufferSize(int64 buffer_size) { + // GLibc returns a pointer with alignment 8 on 32-bit platforms and 16 on + // 64-bit platforms. TCMalloc returns a pointer with alignment 8 for + // allocations smaller than 16 bytes and at least alignment 16 for allocations + // greater than or equal to 16 bytes. N.B. We could improve on this lower + // bound by explicitly allocating the memory with posix_memalign. This is + // complicated by our desire to allow parameter buffers created by clients to + // be consumed directly by the JIT. + if (buffer_size == 0) { + // No need to align empty buffers. 
+ return 1; + } + int pointer_size = module_->getDataLayout().getPointerSize(); + int buffer_alignment = buffer_size >= 16 ? 2 * pointer_size : 8; + DCHECK_GT(buffer_alignment, 0); + + return buffer_alignment; +} + +// Calculate the alignment of a buffer allocated for a given primitive type. +int IrEmitter::MinimumAlignmentForPrimitiveType(PrimitiveType primitive_type) { + int64 buffer_size = ShapeUtil::ByteSizeOfPrimitiveType(primitive_type); + DCHECK_GE(buffer_size, 0); + DCHECK_LE(buffer_size, SIZE_MAX); + + return MinimumAlignmentForBufferSize(buffer_size); +} + +int64 IrEmitter::ByteSizeOf(const Shape& shape) const { + return llvm_ir::ByteSizeOf(shape, module_->getDataLayout()); +} + +// Calculate the alignment of a buffer allocated for a given shape. +int IrEmitter::MinimumAlignmentForShape(const Shape& shape) { + int64 buffer_size = ByteSizeOf(shape); + DCHECK_GE(buffer_size, 0); + DCHECK_LE(buffer_size, SIZE_MAX); + + return MinimumAlignmentForBufferSize(buffer_size); +} + +void IrEmitter::AttachAlignmentMetadataForLoad(llvm::LoadInst* load, + const Shape& shape) { + int alignment = MinimumAlignmentForShape(shape); + if (alignment > 1) { + llvm_ir::SetAlignmentMetadataForLoad(load, alignment); + } +} + +void IrEmitter::AttachAlignmentMetadataForLoad(llvm::LoadInst* load, + int64 buffer_size) { + int alignment = MinimumAlignmentForBufferSize(buffer_size); + if (alignment > 1) { + llvm_ir::SetAlignmentMetadataForLoad(load, alignment); + } +} + +void IrEmitter::AttachDereferenceableMetadataForLoad(llvm::LoadInst* load, + const Shape& shape) { + AttachDereferenceableMetadataForLoad(load, ByteSizeOf(shape)); +} + +void IrEmitter::AttachDereferenceableMetadataForLoad(llvm::LoadInst* load, + int64 buffer_size) { + if (buffer_size > 0) { + llvm_ir::SetDereferenceableMetadataForLoad(load, buffer_size); + } +} + +Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element, + HloInstruction* operand) { + // A tuple is an array of pointers, one for each operand. Each pointer points + // to the output buffer of its corresponding operand. A GetTupleElement + // instruction forwards a pointer to the tuple element buffer at the given + // index. 
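// In terms of plain memory, the tuple layout described in the comment above is
// just an array of element addresses, and get-tuple-element is a single load.
// A minimal sketch (illustrative, not the emitted IR):

#include <cstdint>

void* GetTupleElementAt(void** tuple_buffer, int64_t index) {
  // The tuple buffer holds one pointer per element; get-tuple-element forwards
  // the pointer at `index` without copying the underlying data.
  return tuple_buffer[index];
}

// Usage: a tuple (f32[4], s32[2]) is represented as
//   void* tuple_buffer[2] = {f32_element_buffer, s32_element_buffer};
// and GetTupleElementAt(tuple_buffer, 1) returns s32_element_buffer.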
+ const Shape& shape = get_tuple_element->shape(); + emitted_value_[get_tuple_element] = llvm_ir::EmitGetTupleElement( + shape, get_tuple_element->tuple_index(), MinimumAlignmentForShape(shape), + GetEmittedValueFor(operand), &ir_builder_); + return Status::OK(); +} + +Status IrEmitter::HandleSelect(HloInstruction* select, HloInstruction* pred, + HloInstruction* on_true, + HloInstruction* on_false) { + TF_RET_CHECK(pred->shape().element_type() == PRED); + + if (ShapeUtil::IsTuple(select->shape())) { + TF_ASSIGN_OR_RETURN(llvm::Value * output_address, + EmitTargetAddressForOp(select)); + llvm_ir::EmitTupleSelect(llvm_ir::IrArray(output_address, select->shape()), + GetIrArrayForOp(pred), GetEmittedValueFor(on_true), + GetEmittedValueFor(on_false), &ir_builder_); + emitted_value_[select] = output_address; + return Status::OK(); + } + + return DefaultAction(select); +} + +Status IrEmitter::HandleInfeed(HloInstruction* infeed) { + VLOG(2) << "HandleInfeed: " << infeed->ToString(); + + // The signature of the acquire infeed buffer function is: + // + // (void*)(int32 length); + llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext()); + llvm::Type* int32_type = ir_builder_.getInt32Ty(); + llvm::FunctionType* acquire_type = + llvm::FunctionType::get(i8_ptr_type, {int32_type}, + /*isVarArg=*/false); + + llvm::Function* acquire_func = + llvm::cast(module_->getOrInsertFunction( + runtime::kAcquireInfeedBufferForDequeueSymbolName, acquire_type)); + acquire_func->setCallingConv(llvm::CallingConv::C); + + // The signature of the release infeed buffer function is: + // + // (void)(int32 length, void* buffer); + llvm::FunctionType* release_type = llvm::FunctionType::get( + ir_builder_.getVoidTy(), {int32_type, i8_ptr_type}, + /*isVarArg=*/false); + + llvm::Function* release_func = + llvm::cast(module_->getOrInsertFunction( + runtime::kReleaseInfeedBufferAfterDequeueSymbolName, release_type)); + release_func->setCallingConv(llvm::CallingConv::C); + + const Shape& shape = infeed->shape(); + int64 length = ByteSizeOf(shape); + if (length > std::numeric_limits::max()) { + return InvalidArgument("infeed buffer length %lld is too large", length); + } + int32 length_32 = static_cast(length); + + llvm::Value* acquired_pointer = + ir_builder_.CreateCall(acquire_func, {ir_builder_.getInt32(length_32)}); + + TF_ASSIGN_OR_RETURN(llvm::Value * target_address, + EmitTargetAddressForOp(infeed)); + + ir_builder_.CreateMemCpy(target_address, acquired_pointer, length_32, 1); + + ir_builder_.CreateCall(release_func, + {ir_builder_.getInt32(length_32), acquired_pointer}); + + emitted_value_[infeed] = target_address; + + return Status::OK(); +} + +Status IrEmitter::HandleSort(HloInstruction* sort, HloInstruction* operand) { + // TODO(b/26783907): Implement sort on CPU. + return Unimplemented("sort"); +} + +Status IrEmitter::HandleTuple( + HloInstruction* tuple, + tensorflow::gtl::ArraySlice operands) { + TF_ASSIGN_OR_RETURN(llvm::Value * target_address, + EmitTargetAddressForOp(tuple)); + std::vector base_ptrs; + for (auto operand : operands) { + base_ptrs.push_back(GetEmittedValueFor(operand)); + } + llvm_ir::EmitTuple(llvm_ir::IrArray(target_address, tuple->shape()), + base_ptrs, &ir_builder_); + emitted_value_[tuple] = target_address; + return Status::OK(); +} + +Status IrEmitter::HandleMap( + HloInstruction* map, tensorflow::gtl::ArraySlice operands, + HloComputation* function, + tensorflow::gtl::ArraySlice /*static_operands*/) { + // The called computation should have been emitted previously. 
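// At runtime, the IR emitted by HandleInfeed above amounts to the following
// acquire/copy/release sequence against the CPU runtime. The declarations
// mirror the two FunctionTypes built above and the symbols exercised in
// infeed_manager_test.cc; this is a sketch of the generated behavior, not code
// from the patch.

#include <cstdint>
#include <cstring>

extern "C" void* __xla_cpu_runtime_AcquireInfeedBufferForDequeue(
    int32_t buffer_length);
extern "C" void __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(
    int32_t buffer_length, void* buffer_ptr);

void CopyNextInfeedBufferInto(void* target, int32_t length) {
  void* acquired = __xla_cpu_runtime_AcquireInfeedBufferForDequeue(length);
  std::memcpy(target, acquired, length);  // corresponds to CreateMemCpy above
  __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(length, acquired);
}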
+ llvm::Function* mapped_ir_function = FindOrDie(emitted_functions_, function); + + return EmitTargetElementLoop(map, [this, map, operands, mapped_ir_function]( + const llvm_ir::IrArray::Index& index) { + std::vector parameter_addresses; + for (const HloInstruction* operand : operands) { + const llvm_ir::IrArray& array = GetIrArrayForOp(operand); + parameter_addresses.push_back( + array.EmitArrayElementAddress(index, &ir_builder_)); + } + return EmitElementFunctionCall(mapped_ir_function, map->shape(), + parameter_addresses, "map_function"); + }); +} + +Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window, + HloInstruction* operand, + const Window& window, + HloComputation* function) { + TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( + /*instruction=*/*reduce_window, /*operands=*/{operand}, + /*supported_types=*/{F32})); + + // TODO(b/31410564): Implement dilation for reduce-window. + if (window_util::HasDilation(window)) { + return Unimplemented( + "Dilation for reduce-window not implemented on CPU. See b/31410564."); + } + + // The called computation should have been emitted previously. + llvm::Function* reducer_function = FindOrDie(emitted_functions_, function); + + // Pseudo code for reduce window: + // + // for (coordinates O in the output) + // value = init_value; + // for (coordinates W in the window) + // for each index i: + // input coordinates I_i = O_i * stride_i + W_i - pad_low_i + // if I within bounds of input: + // value = function(value, input(I)); + // output(O) = value; + // + // This is completely un-optimized and just here to have something + // that works. + return EmitTargetElementLoop( + reduce_window, [this, reduce_window, operand, window, + reducer_function](const llvm_ir::IrArray::Index& index) { + // We fold inputs into the accumulator and initialize it to + // the initial value on the reduce_window. + PrimitiveType operand_element_type = operand->shape().element_type(); + llvm::Value* accumulator_address = llvm_ir::EmitAllocaAtFunctionEntry( + llvm_ir::PrimitiveTypeToIrType(operand_element_type, &ir_builder_), + "reduce_window_accumulator_address", &ir_builder_, + MinimumAlignmentForPrimitiveType(operand_element_type)); + ir_builder_.CreateStore(ir_builder_.CreateLoad(GetEmittedValueFor( + reduce_window->operand(1))), + accumulator_address); + + llvm_ir::ForLoopNest loops(&ir_builder_); + std::vector window_size; + for (const auto& dim : window.dimensions()) { + window_size.push_back(dim.size()); + } + const llvm_ir::IrArray::Index window_index = loops.AddLoopsForShape( + ShapeUtil::MakeShape(operand_element_type, window_size), "window"); + CHECK_EQ(window_index.size(), index.size()); + + SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_); + + llvm_ir::IrArray::Index input_index(index.size()); + llvm::Value* in_bounds_condition = nullptr; + for (int64 i = 0; i < index.size(); ++i) { + llvm::Value* strided_index = ir_builder_.CreateNSWMul( + index[i], ir_builder_.getInt64(window.dimensions(i).stride())); + input_index[i] = ir_builder_.CreateNSWSub( + ir_builder_.CreateNSWAdd(strided_index, window_index[i]), + ir_builder_.getInt64(window.dimensions(i).padding_low())); + + // We need to check if 0 <= input_index[i] < bound, as + // otherwise we are in the padding so that we can skip the + // computation. That is equivalent to input_index[i] < bound + // as an *unsigned* comparison, since a negative value will + // wrap to a large positive value. 
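// A scalar, 1-D rendering of the reduce-window pseudo code above, including
// the single unsigned comparison used for the bounds check (a reference sketch
// for illustration, not the emitted IR):

#include <cstdint>
#include <vector>

std::vector<float> ReduceWindow1D(const std::vector<float>& input,
                                  float init_value, int64_t window_size,
                                  int64_t stride, int64_t pad_low,
                                  int64_t output_size,
                                  float (*reducer)(float, float)) {
  std::vector<float> output(output_size, init_value);
  for (int64_t o = 0; o < output_size; ++o) {
    float value = init_value;
    for (int64_t w = 0; w < window_size; ++w) {
      int64_t i = o * stride + w - pad_low;
      // 0 <= i < input.size() is checked with one unsigned comparison: a
      // negative i wraps to a huge unsigned value and fails the test.
      if (static_cast<uint64_t>(i) < input.size()) {
        value = reducer(value, input[i]);
      }
    }
    output[o] = value;
  }
  return output;
}

// E.g. ReduceWindow1D(input, 0.0f, /*window_size=*/3, /*stride=*/2,
//                     /*pad_low=*/1, output_size,
//                     [](float a, float b) { return a + b; })
// computes a padded, strided windowed sum.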
+ llvm::Value* index_condition = ir_builder_.CreateICmpULT( + input_index[i], ir_builder_.getInt64(ShapeUtil::GetDimension( + operand->shape(), i))); + if (in_bounds_condition == nullptr) { + in_bounds_condition = index_condition; + } else { + in_bounds_condition = + ir_builder_.CreateAnd(in_bounds_condition, index_condition); + } + } + CHECK(in_bounds_condition != nullptr); + + llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( + in_bounds_condition, "in-bounds", &ir_builder_); + SetToFirstInsertPoint(if_data.true_block, &ir_builder_); + + // We are not in the padding, so carry out the computation. + llvm_ir::IrArray input_array(GetIrArrayForOp(operand)); + llvm::Value* input_value_address = + input_array.EmitArrayElementAddress(input_index, &ir_builder_); + llvm::Value* result = EmitElementFunctionCall( + reducer_function, reduce_window->shape(), + {accumulator_address, input_value_address}, "reducer_function"); + ir_builder_.CreateStore(result, accumulator_address); + + SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_); + return ir_builder_.CreateLoad(accumulator_address); + }); +} + +Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { + CHECK_EQ(select_and_scatter->operand_count(), 3); + const auto operand = select_and_scatter->operand(0); + const auto source = select_and_scatter->operand(1); + const auto init_value = select_and_scatter->operand(2); + const Window& window = select_and_scatter->window(); + PrimitiveType operand_element_type = operand->shape().element_type(); + const int64 rank = ShapeUtil::Rank(operand->shape()); + CHECK_EQ(rank, ShapeUtil::Rank(source->shape())); + CHECK_EQ(rank, window.dimensions_size()); + + // TODO(b/31410564): Implement dilation for select-and-scatter. + if (window_util::HasDilation(window)) { + return Unimplemented( + "Dilation for select-and-scatter not implemented on CPU. " + "See b/31410564."); + } + + // The select and scatter computations should have been emitted previously. + llvm::Function* select_function = + FindOrDie(emitted_functions_, select_and_scatter->select()); + llvm::Function* scatter_function = + FindOrDie(emitted_functions_, select_and_scatter->scatter()); + + // Pseudo code for select-and-scatter: + // + // initialized_flag is initially off for every window, and is turned on after + // the first iteration is completed and the first operand value is selected. + // + // output(*) = init_value + // for (coordinates S in the source) { + // initialized_flag = false + // for (coordinates W in the window) { + // I = S * stride + W - pad_low + // if I within bounds of operand: + // if !initialized_flag or select(selected_value, operand(I)) == false: + // selected_value = operand(I) + // selected_index = I + // initialized_flag = true + // } + // output(selected_index) = scatter(output(selected_index), source(S)) + // } + // + + // Initialize the output array with the given init_value. + TF_RETURN_IF_ERROR(EmitTargetElementLoop( + select_and_scatter, + [this, init_value](const llvm_ir::IrArray::Index& target_index) { + llvm::Value* init_value_addr = GetEmittedValueFor(init_value); + return ir_builder_.CreateLoad(init_value_addr); + })); + + // Create a loop to iterate over the source array to scatter to the output. 
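// The select-and-scatter pseudo code above, written out for the 1-D case as a
// reference sketch (not the emitted IR). `select` returns true when the
// currently selected value should be kept, matching the PRED result that the
// emitted code tests further below.

#include <cstdint>
#include <vector>

std::vector<float> SelectAndScatter1D(const std::vector<float>& operand,
                                      const std::vector<float>& source,
                                      float init_value, int64_t window_size,
                                      int64_t stride, int64_t pad_low,
                                      bool (*select)(float, float),
                                      float (*scatter)(float, float)) {
  std::vector<float> output(operand.size(), init_value);
  for (int64_t s = 0; s < static_cast<int64_t>(source.size()); ++s) {
    bool initialized_flag = false;
    int64_t selected_index = 0;
    float selected_value = 0.0f;
    for (int64_t w = 0; w < window_size; ++w) {
      int64_t i = s * stride + w - pad_low;
      if (static_cast<uint64_t>(i) >= operand.size()) continue;  // in padding
      if (!initialized_flag || !select(selected_value, operand[i])) {
        selected_value = operand[i];
        selected_index = i;
        initialized_flag = true;
      }
    }
    if (initialized_flag) {  // false only if the window was entirely padding
      output[selected_index] = scatter(output[selected_index], source[s]);
    }
  }
  return output;
}

// With select = "greater or equal" and scatter = "add", each source element is
// added at the arg-max position of its window, which is how max-pool gradients
// are expressed.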
+ llvm_ir::ForLoopNest source_loops(&ir_builder_); + const llvm_ir::IrArray::Index source_index = + source_loops.AddLoopsForShape(source->shape(), "source"); + SetToFirstInsertPoint(source_loops.GetInnerLoopBodyBasicBlock(), + &ir_builder_); + + // Allocate space to keep the currently selected value, its index, and + // the boolean initialized_flag, which is initially set to false. + llvm::Value* selected_value_address = llvm_ir::EmitAllocaAtFunctionEntry( + llvm_ir::PrimitiveTypeToIrType(operand_element_type, &ir_builder_), + "selected_value_address", &ir_builder_, + MinimumAlignmentForPrimitiveType(operand_element_type)); + llvm::Value* selected_index_address = + llvm_ir::EmitAllocaAtFunctionEntryWithCount( + ir_builder_.getInt64Ty(), ir_builder_.getInt32(rank), + "selected_index_address", &ir_builder_); + llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry( + ir_builder_.getInt1Ty(), "initialized_flag_address", &ir_builder_); + ir_builder_.CreateStore(ir_builder_.getInt1(false), initialized_flag_address); + + // Create the inner loop to iterate over the window. + llvm_ir::ForLoopNest window_loops(&ir_builder_); + std::vector window_size; + for (const auto& dim : window.dimensions()) { + window_size.push_back(dim.size()); + } + const llvm_ir::IrArray::Index window_index = window_loops.AddLoopsForShape( + ShapeUtil::MakeShape(operand_element_type, window_size), "window"); + SetToFirstInsertPoint(window_loops.GetInnerLoopBodyBasicBlock(), + &ir_builder_); + + // Compute the operand index to visit and evaluate the condition whether the + // operand index is within the bounds. The unsigned comparison includes + // checking whether the operand index >= 0. + llvm_ir::IrArray::Index operand_index(source_index.size()); + llvm::Value* in_bounds_condition = ir_builder_.getInt1(true); + for (int64 i = 0; i < rank; ++i) { + llvm::Value* strided_index = ir_builder_.CreateNSWMul( + source_index[i], ir_builder_.getInt64(window.dimensions(i).stride())); + operand_index[i] = ir_builder_.CreateNSWSub( + ir_builder_.CreateNSWAdd(strided_index, window_index[i]), + ir_builder_.getInt64(window.dimensions(i).padding_low())); + llvm::Value* index_condition = ir_builder_.CreateICmpULT( + operand_index[i], + ir_builder_.getInt64(ShapeUtil::GetDimension(operand->shape(), i))); + in_bounds_condition = + ir_builder_.CreateAnd(in_bounds_condition, index_condition); + } + CHECK(in_bounds_condition != nullptr); + + // Only need to do something if the operand index is within the bounds. First + // check if the initialized_flag is set. + llvm_ir::LlvmIfData if_in_bounds = + llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &ir_builder_); + SetToFirstInsertPoint(if_in_bounds.true_block, &ir_builder_); + llvm_ir::LlvmIfData if_initialized = + llvm_ir::EmitIfThenElse(ir_builder_.CreateLoad(initialized_flag_address), + "initialized", &ir_builder_); + + // If the initialized_flag is false, initialize the selected value and index + // with the currently visiting operand. 
+ SetToFirstInsertPoint(if_initialized.false_block, &ir_builder_); + const auto save_operand_index = [&]( + const llvm_ir::IrArray::Index& operand_index) { + for (int64 i = 0; i < rank; ++i) { + llvm::Value* selected_index_address_slot = ir_builder_.CreateInBoundsGEP( + selected_index_address, {ir_builder_.getInt32(i)}); + ir_builder_.CreateStore(operand_index[i], selected_index_address_slot); + } + }; + llvm_ir::IrArray operand_array(GetIrArrayForOp(operand)); + llvm::Value* operand_data = + operand_array.EmitReadArrayElement(operand_index, &ir_builder_); + ir_builder_.CreateStore(operand_data, selected_value_address); + save_operand_index(operand_index); + ir_builder_.CreateStore(ir_builder_.getInt1(true), initialized_flag_address); + + // If the initialized_flag is true, call the `select` function to potentially + // update the selected value and index with the currently visiting operand. + SetToFirstInsertPoint(if_initialized.true_block, &ir_builder_); + const Shape output_shape = ShapeUtil::MakeShape(PRED, {}); + llvm::Value* operand_address = + operand_array.EmitArrayElementAddress(operand_index, &ir_builder_); + llvm::Value* result = EmitElementFunctionCall( + select_function, output_shape, {selected_value_address, operand_address}, + "select_function"); + + // If the 'select' function returns false, update the selected value and the + // index to the currently visiting operand. + llvm::Value* cond = ir_builder_.CreateICmpNE( + result, llvm::ConstantInt::get( + llvm_ir::PrimitiveTypeToIrType(PRED, &ir_builder_), 0), + "boolean_predicate"); + llvm_ir::LlvmIfData if_select_lhs = + llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &ir_builder_); + SetToFirstInsertPoint(if_select_lhs.false_block, &ir_builder_); + ir_builder_.CreateStore(ir_builder_.CreateLoad(operand_address), + selected_value_address); + save_operand_index(operand_index); + + // After iterating over the window elements, scatter the source element to + // the selected index of the output. The value we store at the output + // location is computed by calling the `scatter` function with the source + // value and the current output value. 
+ SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(), + &ir_builder_); + llvm_ir::IrArray::Index selected_index; + for (int64 i = 0; i < rank; ++i) { + llvm::Value* selected_index_address_slot = ir_builder_.CreateInBoundsGEP( + selected_index_address, {ir_builder_.getInt32(i)}); + selected_index.push_back( + ir_builder_.CreateLoad(selected_index_address_slot)); + } + llvm_ir::IrArray source_array(GetIrArrayForOp(source)); + llvm::Value* source_value_address = + source_array.EmitArrayElementAddress(source_index, &ir_builder_); + llvm_ir::IrArray output_array(GetIrArrayForOp(select_and_scatter)); + llvm::Value* output_value_address = + output_array.EmitArrayElementAddress(selected_index, &ir_builder_); + llvm::Value* scatter_value = EmitElementFunctionCall( + scatter_function, source->shape(), + {output_value_address, source_value_address}, "scatter_function"); + output_array.EmitWriteArrayElement(selected_index, scatter_value, + &ir_builder_); + + SetToFirstInsertPoint(source_loops.GetOuterLoopExitBasicBlock(), + &ir_builder_); + return Status::OK(); +} + +Status IrEmitter::HandleDot(HloInstruction* dot, HloInstruction* lhs, + HloInstruction* rhs) { + TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( + /*instruction=*/*dot, /*operands=*/{lhs, rhs}, + /*supported_types=*/{F32, F64})); + + llvm_ir::IrArray lhs_array(GetIrArrayForOp(lhs)); + llvm_ir::IrArray rhs_array(GetIrArrayForOp(rhs)); + + Shape target_shape = dot->shape(); + TF_ASSIGN_OR_RETURN(llvm::Value * target_address, + EmitTargetAddressForOp(dot)); + llvm_ir::IrArray target_array(target_address, target_shape); + AddAliasingInformationToIrArray(*dot, &target_array); + + VLOG(2) << "HandleDot: "; + VLOG(2) << " lhs operand: " + << llvm_ir::DumpToString(*lhs_array.GetBasePointer()); + VLOG(2) << " rhs operand: " + << llvm_ir::DumpToString(*rhs_array.GetBasePointer()); + VLOG(2) << " target: " + << llvm_ir::DumpToString(*target_array.GetBasePointer()); + + // Dot operation is complicated so we delegate to a helper class. + TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation( + *dot, /*transpose_lhs=*/false, /*transpose_rhs=*/false, target_array, + lhs_array, rhs_array, GetExecutableRunOptionsArgument(), &ir_builder_)); + + emitted_value_[dot] = target_address; + return Status::OK(); +} + +Status IrEmitter::HandleConvolution(HloInstruction* convolution, + HloInstruction* lhs, HloInstruction* rhs, + const Window& window) { + TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( + /*instruction=*/*convolution, /*operands=*/{lhs, rhs}, + /*supported_types=*/{F32})); + + const ConvolutionDimensionNumbers& dnums = + convolution->convolution_dimension_numbers(); + + if (PotentiallyImplementedAsEigenConvolution(*convolution)) { + const Shape& lhs_shape = lhs->shape(); + const Shape& rhs_shape = rhs->shape(); + const Shape& convolution_shape = convolution->shape(); + // The input, kernel and output agree with respect to layout. + if (LayoutUtil::IsMonotonicWithDim0Major(lhs_shape.layout()) && + LayoutUtil::IsMonotonicWithDim0Major(rhs_shape.layout()) && + LayoutUtil::IsMonotonicWithDim0Major(convolution_shape.layout())) { + llvm::Value* lhs_address = GetEmittedValueFor(lhs); + llvm::Value* rhs_address = GetEmittedValueFor(rhs); + TF_ASSIGN_OR_RETURN(llvm::Value * target_address, + EmitTargetAddressForOp(convolution)); + + const ConvolutionDimensionNumbers& dnums = + convolution->convolution_dimension_numbers(); + + // Input tensor. 
+ const Shape& input_shape = convolution->operand(0)->shape(); + int64 input_batch = input_shape.dimensions(dnums.batch_dimension()); + int64 input_rows = input_shape.dimensions(dnums.spatial_dimensions(0)); + int64 input_cols = input_shape.dimensions(dnums.spatial_dimensions(1)); + int64 input_channels = input_shape.dimensions(dnums.feature_dimension()); + + // Kernel tensor. + const Shape& kernel_shape = convolution->operand(1)->shape(); + int64 kernel_rows = + kernel_shape.dimensions(dnums.kernel_spatial_dimensions(0)); + int64 kernel_cols = + kernel_shape.dimensions(dnums.kernel_spatial_dimensions(1)); + int64 kernel_channels = + kernel_shape.dimensions(dnums.kernel_input_feature_dimension()); + int64 kernel_filters = + kernel_shape.dimensions(dnums.kernel_output_feature_dimension()); + + // Output tensor. + const Shape& convolution_shape = convolution->shape(); + int64 output_rows = + convolution_shape.dimensions(dnums.spatial_dimensions(0)); + int64 output_cols = + convolution_shape.dimensions(dnums.spatial_dimensions(1)); + + // Extract the window stride for the convolution. + const Window& window = convolution->window(); + int64 row_stride = window.dimensions(0).stride(); + int64 col_stride = window.dimensions(1).stride(); + + int64 padding_top = window.dimensions(0).padding_low(); + int64 padding_bottom = window.dimensions(0).padding_high(); + int64 padding_left = window.dimensions(1).padding_low(); + int64 padding_right = window.dimensions(1).padding_high(); + + int64 lhs_row_dilation = window.dimensions(0).base_dilation(); + int64 lhs_col_dilation = window.dimensions(1).base_dilation(); + int64 rhs_row_dilation = window.dimensions(0).window_dilation(); + int64 rhs_col_dilation = window.dimensions(1).window_dilation(); + + // Args have been computed, make the call. + llvm::Type* float_ptr_type = ir_builder_.getFloatTy()->getPointerTo(); + llvm::Type* int64_type = ir_builder_.getInt64Ty(); + llvm::Type* int8_ptr_type = ir_builder_.getInt8Ty()->getPointerTo(); + llvm::FunctionType* conv_type = llvm::FunctionType::get( + ir_builder_.getVoidTy(), + {int8_ptr_type, float_ptr_type, float_ptr_type, float_ptr_type, + int64_type, int64_type, int64_type, int64_type, + int64_type, int64_type, int64_type, int64_type, + int64_type, int64_type, int64_type, int64_type, + int64_type, int64_type, int64_type, int64_type, + int64_type, int64_type, int64_type, int64_type}, + /*isVarArg=*/false); + legacy_flags::CpuRuntimeFlags* flags = legacy_flags::GetCpuRuntimeFlags(); + const char* fn_name = + (flags->xla_cpu_multi_thread_eigen + ? 
runtime::kEigenConvF32SymbolName + : runtime::kEigenSingleThreadedConvF32SymbolName); + llvm::Function* conv_func = llvm::cast( + module_->getOrInsertFunction(fn_name, conv_type)); + conv_func->setCallingConv(llvm::CallingConv::C); + conv_func->setDoesNotThrow(); + conv_func->setOnlyAccessesArgMemory(); + ir_builder_.CreateCall( + conv_func, + { + GetExecutableRunOptionsArgument(), + ir_builder_.CreateBitCast(target_address, float_ptr_type), + ir_builder_.CreateBitCast(lhs_address, float_ptr_type), + ir_builder_.CreateBitCast(rhs_address, float_ptr_type), + ir_builder_.getInt64(input_batch), + ir_builder_.getInt64(input_rows), + ir_builder_.getInt64(input_cols), + ir_builder_.getInt64(input_channels), + ir_builder_.getInt64(kernel_rows), + ir_builder_.getInt64(kernel_cols), + ir_builder_.getInt64(kernel_channels), + ir_builder_.getInt64(kernel_filters), + ir_builder_.getInt64(output_rows), + ir_builder_.getInt64(output_cols), + ir_builder_.getInt64(row_stride), + ir_builder_.getInt64(col_stride), + ir_builder_.getInt64(padding_top), + ir_builder_.getInt64(padding_bottom), + ir_builder_.getInt64(padding_left), + ir_builder_.getInt64(padding_right), + ir_builder_.getInt64(lhs_row_dilation), + ir_builder_.getInt64(lhs_col_dilation), + ir_builder_.getInt64(rhs_row_dilation), + ir_builder_.getInt64(rhs_col_dilation), + }); + emitted_value_[convolution] = target_address; + + return Status::OK(); + } + } + + // This is a completely un-optimized version of convolution just to + // have an early version that works. E.g. the input index and + // padding calculation is not hoisted out of the inner loop. + // + // See the description of convolution in the XLA documentation for the pseudo + // code for convolution. + return EmitTargetElementLoop( + convolution, [this, convolution, lhs, rhs, window, + dnums](const llvm_ir::IrArray::Index& index) { + int num_spatial_dims = dnums.spatial_dimensions_size(); + std::vector output_spatial(num_spatial_dims); + for (int i = 0; i < num_spatial_dims; ++i) { + output_spatial[i] = index[dnums.spatial_dimensions(i)]; + } + llvm::Value* output_feature = index[dnums.feature_dimension()]; + llvm::Value* batch = index[dnums.batch_dimension()]; + + // We will accumulate the products into this sum to calculate + // the output entry at the given index. + PrimitiveType lhs_element_type = lhs->shape().element_type(); + llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry( + llvm_ir::PrimitiveTypeToIrType(lhs_element_type, &ir_builder_), + "convolution_sum_address", &ir_builder_, + MinimumAlignmentForPrimitiveType(lhs_element_type)); + ir_builder_.CreateStore( + llvm::ConstantFP::get(ir_builder_.getFloatTy(), 0.0), sum_address); + + llvm_ir::ForLoopNest loops(&ir_builder_); + std::vector kernel_spatial(num_spatial_dims); + for (int i = 0; i < num_spatial_dims; ++i) { + kernel_spatial[i] = + loops + .AddLoop(0, rhs->shape().dimensions( + dnums.kernel_spatial_dimensions(i)), + tensorflow::strings::StrCat("k", i)) + ->GetIndVarValue(); + } + llvm::Value* input_feature = + loops + .AddLoop(0, lhs->shape().dimensions(dnums.feature_dimension()), + "iz") + ->GetIndVarValue(); + + SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_); + + // Calculate the spatial index in the input array, taking striding, + // dilation and padding into account. An index in the padding will be + // out of the bounds of the array. 
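// Spelled out as a C prototype, the Eigen convolution entry point selected
// above takes the executable run options pointer, the output/input/kernel
// buffers and twenty int64 geometry parameters, in the order passed to the
// CreateCall above. The function name below is a placeholder for illustration;
// the real symbol is whichever of runtime::kEigenConvF32SymbolName or
// runtime::kEigenSingleThreadedConvF32SymbolName is chosen.

#include <cstdint>

extern "C" void xla_eigen_conv_f32_entry_point(  // placeholder name
    const void* run_options,  // points at xla::ExecutableRunOptions
    float* out, float* lhs, float* rhs,
    int64_t input_batch, int64_t input_rows, int64_t input_cols,
    int64_t input_channels, int64_t kernel_rows, int64_t kernel_cols,
    int64_t kernel_channels, int64_t kernel_filters, int64_t output_rows,
    int64_t output_cols, int64_t row_stride, int64_t col_stride,
    int64_t padding_top, int64_t padding_bottom, int64_t padding_left,
    int64_t padding_right, int64_t lhs_row_dilation, int64_t lhs_col_dilation,
    int64_t rhs_row_dilation, int64_t rhs_col_dilation);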
+ const auto calculate_input_index = [this]( + llvm::Value* output_index, llvm::Value* kernel_index, + const WindowDimension& window_dim) { + llvm::Value* strided_index = ir_builder_.CreateNSWMul( + output_index, ir_builder_.getInt64(window_dim.stride())); + llvm::Value* dilated_kernel_index = ir_builder_.CreateNSWMul( + kernel_index, ir_builder_.getInt64(window_dim.window_dilation())); + return ir_builder_.CreateNSWSub( + ir_builder_.CreateNSWAdd(strided_index, dilated_kernel_index), + ir_builder_.getInt64(window_dim.padding_low())); + }; + std::vector input_spatial(num_spatial_dims); + for (int i = 0; i < num_spatial_dims; ++i) { + input_spatial[i] = calculate_input_index( + output_spatial[i], kernel_spatial[i], window.dimensions(i)); + } + + // We need to check if 0 <= input dim < bound, as otherwise we are in + // the padding so that we can skip the computation. That is equivalent + // to input dim < bound as an *unsigned* comparison, since a negative + // value will wrap to a large positive value. The input dim is dilated, + // so we need to dilate the bound as well to match. + + // Also need to check that the input coordinates are not in one of the + // holes created by base dilation. + const auto not_in_hole = [&](llvm::Value* input_index, + int64 base_dilation) { + llvm::Value* remainder = ir_builder_.CreateSRem( + input_index, ir_builder_.getInt64(base_dilation)); + return ir_builder_.CreateICmpEQ(remainder, ir_builder_.getInt64(0)); + }; + + llvm::Value* in_bounds_condition = nullptr; + for (int i = 0; i < num_spatial_dims; ++i) { + llvm::ConstantInt* input_bound = + ir_builder_.getInt64(window_util::DilatedBound( + lhs->shape().dimensions(dnums.spatial_dimensions(i)), + window.dimensions(i).base_dilation())); + llvm::Value* dim_in_bound = + ir_builder_.CreateICmpULT(input_spatial[i], input_bound); + llvm::Value* dim_not_in_hole = not_in_hole( + input_spatial[i], window.dimensions(i).base_dilation()); + llvm::Value* dim_ok = + ir_builder_.CreateAnd(dim_in_bound, dim_not_in_hole); + in_bounds_condition = + in_bounds_condition + ? ir_builder_.CreateAnd(in_bounds_condition, dim_ok) + : dim_ok; + } + + // Now we need to map the dilated base coordinates back to the actual + // data indices on the lhs. + const auto undilate = [&](llvm::Value* input_index, + int64 base_dilation) { + return ir_builder_.CreateSDiv(input_index, + ir_builder_.getInt64(base_dilation)); + }; + for (int i = 0; i < num_spatial_dims; ++i) { + input_spatial[i] = + undilate(input_spatial[i], window.dimensions(i).base_dilation()); + } + + llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( + in_bounds_condition, "in-bounds", &ir_builder_); + SetToFirstInsertPoint(if_data.true_block, &ir_builder_); + + // We are not in the padding, so carry out the computation. 
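// Worked example of the base-dilation handling above (illustration only).
// With an lhs spatial dimension of size 4 and base_dilation 2, the dilated
// axis looks like  d0 _ d1 _ d2 _ d3  (7 positions): even positions hold data,
// odd positions are the holes.

#include <cstdint>

int64_t DilatedBound1D(int64_t bound, int64_t base_dilation) {
  // Size of the axis after inserting (base_dilation - 1) holes between
  // neighboring elements; matches window_util::DilatedBound for bound > 0.
  return (bound - 1) * base_dilation + 1;
}

bool InBoundsAndNotInHole(int64_t dilated_index, int64_t bound,
                          int64_t base_dilation) {
  return static_cast<uint64_t>(dilated_index) <
             static_cast<uint64_t>(DilatedBound1D(bound, base_dilation)) &&
         dilated_index % base_dilation == 0;
}

int64_t Undilate(int64_t dilated_index, int64_t base_dilation) {
  // Map a dilated coordinate back to the index of the actual lhs element.
  return dilated_index / base_dilation;
}

// E.g. dilated index 4 with base_dilation 2 is in bounds and not a hole, and
// maps back to lhs element 2; dilated index 3 is a hole and is skipped.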
+ int num_dims = num_spatial_dims + 2; + llvm_ir::IrArray::Index input_index(num_dims); + for (int i = 0; i < num_spatial_dims; ++i) { + input_index[dnums.spatial_dimensions(i)] = input_spatial[i]; + } + input_index[dnums.feature_dimension()] = input_feature; + input_index[dnums.batch_dimension()] = batch; + + llvm_ir::IrArray kernel_array(GetIrArrayForOp(rhs)); + llvm_ir::IrArray::Index kernel_index(num_dims); + for (int i = 0; i < num_spatial_dims; ++i) { + kernel_index[dnums.kernel_spatial_dimensions(i)] = kernel_spatial[i]; + } + kernel_index[dnums.kernel_input_feature_dimension()] = input_feature; + kernel_index[dnums.kernel_output_feature_dimension()] = output_feature; + + llvm_ir::IrArray input_array(GetIrArrayForOp(lhs)); + llvm::Value* product = ir_builder_.CreateFMul( + input_array.EmitReadArrayElement(input_index, &ir_builder_), + kernel_array.EmitReadArrayElement(kernel_index, &ir_builder_)); + llvm::Value* sum = ir_builder_.CreateFAdd( + ir_builder_.CreateLoad(sum_address), product); + ir_builder_.CreateStore(sum, sum_address); + + SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_); + return ir_builder_.CreateLoad(sum_address); + }); +} + +Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) { + // TODO(b/33011107): Support cross replica sum on CPU. + return Unimplemented( + "Cross replica sum not implemented on CPU. See b/33011107."); +} + +Status IrEmitter::HandleParameter(HloInstruction* parameter) { + VLOG(2) << "HandleParameter: " << parameter->ToString(); + auto param_number = parameter->parameter_number(); + auto param_shape = parameter->shape(); + + // We have to access the parameter at offset param_number in the params + // array. The code generated here is equivalent to this C code: + // + // i8* param_address_untyped = params[param_number]; + // Param* param_address_typed = (Param*)param_address_untyped; + // + // Where Param is the actual element type of the underlying buffer (for + // example, float for an XLA F32 element type). + llvm::Argument* params = GetArg(compute_function_, 2); + llvm::Value* param_address_offset = + llvm_ir::EmitBufferIndexingGEP(params, param_number, &ir_builder_); + llvm::LoadInst* param_address_untyped = + ir_builder_.CreateLoad(param_address_offset); + llvm::Value* param_address_typed = ir_builder_.CreateBitCast( + param_address_untyped, IrShapeType(param_shape)->getPointerTo()); + emitted_value_[parameter] = param_address_typed; + + // Parameters of different types may not alias one another. + llvm_ir::SetTbaaForInstruction(param_address_untyped, param_shape, + /*is_pointer_to=*/true); + if (!ShapeUtil::IsOpaque(param_shape)) { + AttachAlignmentMetadataForLoad(param_address_untyped, param_shape); + AttachDereferenceableMetadataForLoad(param_address_untyped, param_shape); + } + + VLOG(2) << " emitted value: " << llvm_ir::DumpToString(*param_address_typed); + return Status::OK(); +} + +Status IrEmitter::HandleReduce(HloInstruction* reduce, HloInstruction* arg, + HloInstruction* init_value, + tensorflow::gtl::ArraySlice dimensions, + HloComputation* function) { + // The called computation should have been emitted previously. + llvm::Function* reducer_function = FindOrDie(emitted_functions_, function); + return EmitTargetElementLoop( + reduce, [this, reduce, arg, init_value, dimensions, + reducer_function](const llvm_ir::IrArray::Index& index) { + // Initialize an accumulator with init_value. 
+ PrimitiveType accumulator_type = reduce->shape().element_type(); + llvm::AllocaInst* accumulator_addr = llvm_ir::EmitAllocaAtFunctionEntry( + llvm_ir::PrimitiveTypeToIrType(accumulator_type, &ir_builder_), + "accumulator", &ir_builder_, + MinimumAlignmentForPrimitiveType(accumulator_type)); + llvm::Value* init_value_addr = GetEmittedValueFor(init_value); + llvm::Value* load_init_value = ir_builder_.CreateLoad(init_value_addr); + ir_builder_.CreateStore(load_init_value, accumulator_addr); + + // The enclosing loops go over all the target elements. Now we have to + // compute the actual target element. For this, we build a new loop nest + // to iterate over all the reduction dimensions in the argument. + // AddLoopsForShapeOnDimensions will return an Index where induction + // Value*s are placed for each dimension in dimensions, and all the rest + // are nullptrs. + llvm_ir::ForLoopNest loops(&ir_builder_); + const llvm_ir::IrArray::Index reduced_dims_index = + loops.AddLoopsForShapeOnDimensions(arg->shape(), dimensions, + "reduction_dim"); + + SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_); + + // Build a full index for the input argument, using reduced_dims_index + // as the base. In reduced_dims_index only the reduction dimensions are + // filled in. We fill in the rest of the dimensions with induction + // Value*s taken from 'index' which iterates over the target array. + // See the high-level description in the XLA documentation for details. + llvm_ir::IrArray arg_array(GetIrArrayForOp(arg)); + llvm_ir::IrArray::Index input_index = reduced_dims_index; + llvm_ir::IrArray::Index::const_iterator it = index.begin(); + + for (int64 i = 0; i < input_index.size(); ++i) { + if (input_index[i] == nullptr) { + input_index[i] = *it++; + } + } + CHECK(index.end() == it); + + // Apply the reduction function to the loaded value. + llvm::Value* input_address = + arg_array.EmitArrayElementAddress(input_index, &ir_builder_); + llvm::Value* result = EmitElementFunctionCall( + reducer_function, reduce->shape(), + {accumulator_addr, input_address}, "reduce_function"); + ir_builder_.CreateStore(result, accumulator_addr); + + SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_); + return ir_builder_.CreateLoad(accumulator_addr); + }); +} + +Status IrEmitter::HandleSend(HloInstruction* send) { + // TODO(b/33942983): Support Send/Recv on CPU. + return Unimplemented("Send is not implemented on CPU. See b/33942983."); +} + +Status IrEmitter::HandleRecv(HloInstruction* recv) { + // TODO(b/33942983): Support Send/Recv on CPU. + return Unimplemented("Recv is not implemented on CPU. See b/33942983."); +} + +Status IrEmitter::HandlePad(HloInstruction* pad) { + // First, fill in the padding value to all output elements. + TF_RETURN_IF_ERROR(EmitTargetElementLoop( + pad, [this, pad](const llvm_ir::IrArray::Index& target_index) { + const HloInstruction* padding_value = pad->operand(1); + llvm::Value* padding_value_addr = GetEmittedValueFor(padding_value); + return ir_builder_.CreateLoad(padding_value_addr); + })); + + // Create a loop to iterate over the operand elements and update the output + // locations where the operand elements should be stored. + llvm_ir::ForLoopNest loops(&ir_builder_); + const HloInstruction* operand = pad->operand(0); + const llvm_ir::IrArray::Index operand_index = + loops.AddLoopsForShape(operand->shape(), "operand"); + + SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_); + + // Load an element from the operand. 
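// A concrete rendering of the index bookkeeping above (sketch, not emitted
// IR): reducing an f32[2,3,4] over dimensions={1} produces an f32[2,4]. For a
// target element (i, k), the reduction loop supplies the middle coordinate r,
// and the full input index is assembled as (i, r, k), which is the
// "fill the nullptr slots from the target index" step.

#include <cstdint>
#include <vector>

// input is dense row-major f32[d0,d1,d2]; reduce over dimension 1.
std::vector<float> ReduceDim1(const std::vector<float>& input, int64_t d0,
                              int64_t d1, int64_t d2, float init_value,
                              float (*reducer)(float, float)) {
  std::vector<float> output(d0 * d2);
  for (int64_t i = 0; i < d0; ++i) {
    for (int64_t k = 0; k < d2; ++k) {
      float accumulator = init_value;     // accumulator starts at init_value
      for (int64_t r = 0; r < d1; ++r) {  // loop over the reduced dimension
        accumulator = reducer(accumulator, input[(i * d1 + r) * d2 + k]);
      }
      output[i * d2 + k] = accumulator;
    }
  }
  return output;
}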
+ llvm_ir::IrArray operand_array(GetIrArrayForOp(operand)); + llvm::Value* operand_data = + operand_array.EmitReadArrayElement(operand_index, &ir_builder_); + + // Compute the output index the operand element should be assigned to. + // output_index := edge_padding_low + operand_index * (interior_padding + 1) + const PaddingConfig& padding_config = pad->padding_config(); + llvm_ir::IrArray::Index output_index; + for (int64 i = 0; i < operand_index.size(); ++i) { + llvm::Value* offset = ir_builder_.CreateMul( + operand_index[i], + ir_builder_.getInt64(padding_config.dimensions(i).interior_padding() + + 1)); + llvm::Value* index = ir_builder_.CreateAdd( + offset, + ir_builder_.getInt64(padding_config.dimensions(i).edge_padding_low())); + output_index.push_back(index); + } + + // Store the operand element to the computed output location. + llvm_ir::IrArray output_array(GetIrArrayForOp(pad)); + output_array.EmitWriteArrayElement(output_index, operand_data, &ir_builder_); + + SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_); + return Status::OK(); +} + +// If `hlo` is a Transpose, returns its operand; otherwise returns `hlo` itself. +static const HloInstruction* StripTranspose(const HloInstruction& hlo) { + if (hlo.IsRank2Transpose()) { + return hlo.operand(0); + } + return &hlo; +} + +Status IrEmitter::HandleFusion(HloInstruction* fusion) { + if (fusion->fusion_kind() == HloInstruction::FusionKind::kTransposeDot) { + const HloInstruction* dot = fusion->fused_expression_root(); + DCHECK(dot->opcode() == HloOpcode::kDot); + const HloInstruction* lhs_parameter = StripTranspose(*dot->operand(0)); + const HloInstruction* rhs_parameter = StripTranspose(*dot->operand(1)); + DCHECK(lhs_parameter->opcode() == HloOpcode::kParameter && + rhs_parameter->opcode() == HloOpcode::kParameter); + const HloInstruction* lhs = + fusion->operand(lhs_parameter->parameter_number()); + const HloInstruction* rhs = + fusion->operand(rhs_parameter->parameter_number()); + + TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( + /*instruction=*/*dot, /*operands=*/{lhs, rhs}, + /*supported_types=*/{F32})); + + llvm_ir::IrArray lhs_array(GetIrArrayForOp(lhs)); + llvm_ir::IrArray rhs_array(GetIrArrayForOp(rhs)); + + Shape target_shape = fusion->shape(); + TF_ASSIGN_OR_RETURN(llvm::Value * target_address, + EmitTargetAddressForOp(fusion)); + llvm_ir::IrArray target_array(target_address, target_shape); + AddAliasingInformationToIrArray(*fusion, &target_array); + + VLOG(2) << "HandleFusion kTransposeDot: "; + VLOG(2) << " lhs operand: " + << llvm_ir::DumpToString(*lhs_array.GetBasePointer()); + VLOG(2) << " rhs operand: " + << llvm_ir::DumpToString(*rhs_array.GetBasePointer()); + VLOG(2) << " target: " + << llvm_ir::DumpToString(*target_array.GetBasePointer()); + + // Dot operation is complicated so we delegate to a helper class. 
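+ // (Conceptually, with illustrative index names, the helper emits something
+ // equivalent to
+ //
+ //   for each (i, j) in the target shape:
+ //     target[i, j] = sum over k of lhs[i, k] * rhs[k, j]
+ //
+ // with the rank-2 transposes of the fused operands folded into the index
+ // order instead of being materialized.)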
+ TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation( + *dot, dot->operand(0)->IsRank2Transpose(), + dot->operand(1)->IsRank2Transpose(), target_array, lhs_array, rhs_array, + GetExecutableRunOptionsArgument(), &ir_builder_)); + + emitted_value_[fusion] = target_address; + return Status::OK(); + } else if (fusion->fusion_kind() == HloInstruction::FusionKind::kLoop) { + std::vector parameter_arrays; + for (HloInstruction* operand : fusion->operands()) { + parameter_arrays.push_back(GetIrArrayForOp(operand)); + } + CpuElementalIrEmitter elemental_emitter(hlo_module_config_, &ir_builder_, + module_); + FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter); + TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter)); + + return EmitTargetElementLoop(fusion, fused_emitter.GetRootGenerator()); + } else { + return Unimplemented("Fusion kind not implemented on CPU"); + } +} + +Status IrEmitter::HandleCall( + HloInstruction* call, tensorflow::gtl::ArraySlice operands, + HloComputation* computation) { + llvm::Function* call_ir_function = FindOrDie(emitted_functions_, computation); + + std::vector parameter_addresses; + for (HloInstruction* operand : operands) { + parameter_addresses.push_back(GetEmittedValueFor(operand)); + } + + TF_ASSIGN_OR_RETURN(llvm::Value * output_address, + EmitTargetAddressForOp(call)); + + EmitArrayFunctionCallInto(call_ir_function, parameter_addresses, + output_address, computation->name()); + + emitted_value_[call] = output_address; + return Status::OK(); +} + +Status IrEmitter::HandleCustomCall( + HloInstruction* custom_call, + tensorflow::gtl::ArraySlice operands, + tensorflow::StringPiece custom_call_target) { + llvm::Type* i8_ptr_type = ir_builder_.getInt8PtrTy(); + llvm::AllocaInst* operands_alloca = + llvm_ir::EmitAllocaAtFunctionEntryWithCount( + i8_ptr_type, ir_builder_.getInt32(operands.size()), + "cc_operands_alloca", &ir_builder_); + for (int i = 0; i < operands.size(); ++i) { + const HloInstruction* operand = operands[i]; + llvm::Value* operand_as_i8ptr = + ir_builder_.CreatePointerCast(GetEmittedValueFor(operand), i8_ptr_type); + llvm::Value* slot_in_operands_alloca = ir_builder_.CreateInBoundsGEP( + operands_alloca, {ir_builder_.getInt32(i)}); + ir_builder_.CreateStore(operand_as_i8ptr, slot_in_operands_alloca); + } + auto* custom_call_ir_function = + llvm::cast(module_->getOrInsertFunction( + llvm_ir::AsStringRef(custom_call_target), + llvm::FunctionType::get( + /*Result=*/ir_builder_.getVoidTy(), + /*Params=*/{i8_ptr_type, operands_alloca->getType()}, + /*isVarArg=*/false))); + + TF_ASSIGN_OR_RETURN(llvm::Value * output_address, + EmitTargetAddressForOp(custom_call)); + auto* output_address_arg = + ir_builder_.CreatePointerCast(output_address, i8_ptr_type); + + ir_builder_.CreateCall(custom_call_ir_function, + {output_address_arg, operands_alloca}); + + emitted_value_[custom_call] = output_address; + return Status::OK(); +} + +Status IrEmitter::HandleWhile(HloInstruction* xla_while, HloInstruction* init, + HloComputation* condition, HloComputation* body) { + // Precondition: Condition computation must return a scalar bool. + TF_RET_CHECK(ShapeUtil::IsScalar(condition->root_instruction()->shape()) && + condition->root_instruction()->shape().element_type() == PRED) + << "While condition computation must return bool"; + // Check that all while-related buffers share an allocation. 
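+ // The loop emitted below updates its state in place: a single buffer holds
+ // the loop-carried value before, during, and after every iteration, so the
+ // while result, its init operand, the condition and body parameters, and
+ // the body root must all have been assigned the same allocation by buffer
+ // assignment.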
+ TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshape( + xla_while->shape(), + [this, &xla_while](const Shape& /*subshape*/, + const ShapeIndex& index) -> Status { + auto check = [this](const HloInstruction* a, const HloInstruction* b, + const ShapeIndex& index) { + BufferAllocation::Index index_a = + assignment_.GetUniqueAllocation(a, index) + .ConsumeValueOrDie() + ->index(); + BufferAllocation::Index index_b = + assignment_.GetUniqueAllocation(b, index) + .ConsumeValueOrDie() + ->index(); + if (index_a != index_b) { + return InternalError( + "instruction %s does not share allocation with " + "instruction %s ", + a->ToString().c_str(), b->ToString().c_str()); + } + return Status::OK(); + }; + TF_RETURN_IF_ERROR(check(xla_while, xla_while->operand(0), index)); + TF_RETURN_IF_ERROR(check( + xla_while, xla_while->while_condition()->parameter_instruction(0), + index)); + TF_RETURN_IF_ERROR( + check(xla_while, xla_while->while_body()->parameter_instruction(0), + index)); + TF_RETURN_IF_ERROR(check( + xla_while, xla_while->while_body()->root_instruction(), index)); + return Status::OK(); + })); + + // Set emitted value to that of 'init' with which it shares an allocation. + emitted_value_[xla_while] = GetEmittedValueFor(init); + + // The called computation should have been emitted previously. + llvm::Function* condition_ir_function = + FindOrDie(emitted_functions_, condition); + llvm::Function* body_ir_function = FindOrDie(emitted_functions_, body); + + // Generating: + // while (Condition(while_result)) { + // // CopyInsertion pass inserts copies which enable 'while_result' to + // // be passed back in as 'Body' parameter. + // while_result = Body(while_result); // Insert + // } + + // Terminates the current block with a branch to a while header. + llvm::BasicBlock* header_bb = llvm::BasicBlock::Create( + module_->getContext(), "while_header", compute_function_); + ir_builder_.CreateBr(header_bb); + ir_builder_.SetInsertPoint(header_bb); + + // Calls the condition function to determine whether to proceed with the + // body. It must return a bool, so use the scalar call form. + llvm::Value* while_result = GetEmittedValueFor(xla_while); + llvm::Value* while_condition = EmitElementFunctionCall( + condition_ir_function, condition->root_instruction()->shape(), + {while_result}, "condition_function"); + llvm::Value* while_predicate = ir_builder_.CreateICmpNE( + while_condition, + llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, &ir_builder_), + 0)); + + // Branches to the body or to the while exit depending on the condition. + llvm::BasicBlock* body_bb = llvm::BasicBlock::Create( + module_->getContext(), "while_body", compute_function_); + llvm::BasicBlock* exit_bb = + llvm::BasicBlock::Create(module_->getContext(), "while__exit"); + ir_builder_.CreateCondBr(while_predicate, body_bb, exit_bb); + + // Calls the body function from the body block. + ir_builder_.SetInsertPoint(body_bb); + + // Calls the body function. + EmitArrayFunctionCallInto(body_ir_function, {while_result}, while_result, + "while_body"); + // Finishes with a branch back to the header. + ir_builder_.CreateBr(header_bb); + + // Adds the exit block to the function and sets the insert point there. + compute_function_->getBasicBlockList().push_back(exit_bb); + ir_builder_.SetInsertPoint(exit_bb); + + return Status::OK(); +} + +Status IrEmitter::FinishVisit(HloInstruction* root) { + // When this method is called, we should have already emitted an IR value for + // the root (return) op. 
The IR value holds the address of the buffer holding + // the value. If the root is a constant or parameter, we perform a memcpy from + // this buffer to the retval buffer of the computation. Otherwise, there's + // nothing to do since the result was already written directly into the output + // buffer. + VLOG(2) << "FinishVisit root: " << root->ToString(); + llvm::Value* root_value = GetEmittedValueFor(root); + VLOG(2) << " value: " << llvm_ir::DumpToString(*root_value); + + if (auto* prof_counter = GetProfileCounterFor(/*hlo=*/nullptr)) { + profiling_state_.RecordCompleteComputation(&ir_builder_, prof_counter); + } + + ir_builder_.CreateRetVoid(); + return Status::OK(); +} + +llvm::Value* IrEmitter::GetProfileCounterFor(const HloInstruction* hlo) { + string counter_name; + size_t prof_counter_idx; + if (!hlo_to_profile_idx_) { + return nullptr; + } + if (hlo) { + auto it = hlo_to_profile_idx_->find(hlo); + if (it == hlo_to_profile_idx_->end()) { + return nullptr; + } + + prof_counter_idx = it->second; + uintptr_t hlo_address = reinterpret_cast(hlo); + counter_name = tensorflow::strings::StrCat( + "prof_counter_0x", + tensorflow::strings::Hex( + hlo_address, tensorflow::strings::PadSpec(sizeof(hlo_address)))); + } else { + prof_counter_idx = hlo_to_profile_idx_->size(); + counter_name = "prof_counter_computation"; + } + return ir_builder_.CreateGEP(GetProfileCountersArgument(), + ir_builder_.getInt64(prof_counter_idx), + llvm_ir::AsStringRef(counter_name)); +} + +void IrEmitter::ProfilingState::UpdateProfileCounter( + llvm::IRBuilder<>* ir_builder, llvm::Value* prof_counter, + llvm::Value* cycle_end, llvm::Value* cycle_start) { + auto* cycle_diff = ir_builder->CreateSub(cycle_end, cycle_start); + llvm::LoadInst* old_cycle_count = + ir_builder->CreateLoad(prof_counter, "old_cycle_count"); + auto* new_cycle_count = + ir_builder->CreateAdd(cycle_diff, old_cycle_count, "new_cycle_count"); + ir_builder->CreateStore(new_cycle_count, prof_counter); +} + +llvm::Value* IrEmitter::ProfilingState::ReadCycleCounter( + llvm::IRBuilder<>* ir_builder) { + llvm::Module* module = ir_builder->GetInsertBlock()->getModule(); + if (use_rdtscp_) { + llvm::Function* func_llvm_readcyclecounter = + llvm::Intrinsic::getDeclaration(module, + llvm::Intrinsic::readcyclecounter); + return ir_builder->CreateCall(func_llvm_readcyclecounter); + } + llvm::Function* func_llvm_x86_rdtscp = + llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::x86_rdtscp); + if (!aux_i8ptr_) { + llvm::AllocaInst* rdtscp_aux = llvm_ir::EmitAllocaAtFunctionEntry( + ir_builder->getInt32Ty(), "rdtscp_aux", ir_builder); + aux_i8ptr_ = + ir_builder->CreateBitCast(rdtscp_aux, ir_builder->getInt8PtrTy()); + } + llvm::ConstantInt* alloca_size = ir_builder->getInt64(4); + llvm::Function* func_llvm_lifetime_start = + llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::lifetime_start); + ir_builder->CreateCall(func_llvm_lifetime_start, {alloca_size, aux_i8ptr_}); + llvm::Value* rdtscp_call = + ir_builder->CreateCall(func_llvm_x86_rdtscp, aux_i8ptr_); + llvm::Function* func_llvm_lifetime_end = + llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::lifetime_end); + ir_builder->CreateCall(func_llvm_lifetime_end, {alloca_size, aux_i8ptr_}); + return rdtscp_call; +} + +void IrEmitter::ProfilingState::RecordCycleStart(llvm::IRBuilder<>* ir_builder, + HloInstruction* hlo) { + auto* cycle_start = ReadCycleCounter(ir_builder); + cycle_starts_[hlo] = cycle_start; + if (first_read_cycle_start_ == nullptr) { + first_read_cycle_start_ = cycle_start; + 
} +} + +void IrEmitter::ProfilingState::RecordCycleDelta(llvm::IRBuilder<>* ir_builder, + HloInstruction* hlo, + llvm::Value* prof_counter) { + auto* cycle_end = ReadCycleCounter(ir_builder); + auto* cycle_start = cycle_starts_[hlo]; + UpdateProfileCounter(ir_builder, prof_counter, cycle_end, cycle_start); + last_read_cycle_end_ = cycle_end; +} + +void IrEmitter::ProfilingState::RecordCompleteComputation( + llvm::IRBuilder<>* ir_builder, llvm::Value* prof_counter) { + if (is_entry_computation_ && last_read_cycle_end_ && + first_read_cycle_start_) { + UpdateProfileCounter(ir_builder, prof_counter, last_read_cycle_end_, + first_read_cycle_start_); + } +} + +Status IrEmitter::Preprocess(HloInstruction* hlo) { + if (hlo_to_profile_idx_ && hlo_to_profile_idx_->count(hlo)) { + profiling_state_.RecordCycleStart(&ir_builder_, hlo); + } + return Status::OK(); +} + +Status IrEmitter::Postprocess(HloInstruction* hlo) { + if (auto* prof_counter = GetProfileCounterFor(hlo)) { + profiling_state_.RecordCycleDelta(&ir_builder_, hlo, prof_counter); + } + return Status::OK(); +} + +llvm_ir::IrArray IrEmitter::GetIrArrayForOp(const HloInstruction* hlo) { + llvm::Value* value_for_op = GetEmittedValueFor(hlo); + + llvm_ir::IrArray array(value_for_op, hlo->shape()); + AddAliasingInformationToIrArray(*hlo, &array); + return array; +} + +llvm::Value* IrEmitter::GetEmittedValueFor(const HloInstruction* hlo) { + auto it = emitted_value_.find(hlo); + if (it == emitted_value_.end()) { + LOG(FATAL) << "could not find emitted value for: " << hlo->ToString(); + } + return it->second; +} + +llvm::Type* IrEmitter::IrShapeType(const Shape& shape) { + return llvm_ir::ShapeToIrType(shape, &ir_builder_); +} + +llvm::Argument* IrEmitter::GetResultArgument() { + return GetArg(compute_function_, 0); +} + +llvm::Argument* IrEmitter::GetProfileCountersArgument() { + return hlo_to_profile_idx_ ? GetArg(compute_function_, 4) : nullptr; +} + +llvm::Value* IrEmitter::GetTempBuffersArgument() { + return GetArg(compute_function_, 3); +} + +llvm::Value* IrEmitter::GetExecutableRunOptionsArgument() { + return GetArg(compute_function_, 1); +} + +llvm::Value* IrEmitter::EmitTempBufferPointer( + BufferAllocation::Index temp_buf_index, const Shape& target_shape) { + llvm::Type* element_type = IrShapeType(target_shape); + // The alignment and number of bytes within the temporary buffer is determined + // by the maximal shape as determined by buffer assignment. + const BufferAllocation& allocation = + assignment_.GetAllocation(temp_buf_index); + if (allocation.is_thread_local()) { + // Thread-local allocations should only be assigned a single buffer. 
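+ // Such an allocation is private to the function being emitted, so it is
+ // backed by a function-entry alloca (keyed below on the pair of current
+ // function and allocation index) rather than by a slot in the "temps"
+ // argument.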
+ CHECK_EQ(1, allocation.assigned_buffers().size()); + const Shape& shape = allocation.assigned_buffers()[0]->shape(); + + llvm::AllocaInst*& tempbuf_address = thread_local_buffers_[{ + ir_builder_.GetInsertBlock()->getParent(), temp_buf_index}]; + if (tempbuf_address == nullptr) { + tempbuf_address = llvm_ir::EmitAllocaAtFunctionEntry( + IrShapeType(shape), + tensorflow::strings::StrCat("thread_local", temp_buf_index), + &ir_builder_, MinimumAlignmentForShape(target_shape)); + } + return ir_builder_.CreateBitCast(tempbuf_address, + element_type->getPointerTo()); + } + + llvm::Value* tempbuf_address_offset = llvm_ir::EmitBufferIndexingGEP( + GetTempBuffersArgument(), temp_buf_index, &ir_builder_); + llvm::LoadInst* tempbuf_address_untyped = + ir_builder_.CreateLoad(tempbuf_address_offset); + // Loading the address of a buffer is invariant of the point at which the + // load is executed in the program because we never reassign buffers. + tempbuf_address_untyped->setMetadata( + llvm::LLVMContext::MD_invariant_load, + llvm::MDNode::get(tempbuf_address_untyped->getContext(), /*MDs=*/{})); + llvm_ir::SetTbaaForInstruction(tempbuf_address_untyped, target_shape, + /*is_pointer_to=*/true); + + AttachAlignmentMetadataForLoad(tempbuf_address_untyped, allocation.size()); + AttachDereferenceableMetadataForLoad(tempbuf_address_untyped, + allocation.size()); + return ir_builder_.CreateBitCast(tempbuf_address_untyped, + element_type->getPointerTo()); +} + +// Emits a function call returning a single array element. Allocates space +// for a single element_type value, and loads it after call. +llvm::Value* IrEmitter::EmitElementFunctionCall( + llvm::Function* function, const Shape& return_shape, + tensorflow::gtl::ArraySlice parameter_addresses, + tensorflow::StringPiece name) { + llvm::Value* return_value_buffer = EmitArrayFunctionCall( + function, return_shape, 1, parameter_addresses, name); + return ir_builder_.CreateLoad( + return_value_buffer, + llvm_ir::AsStringRef(tensorflow::strings::StrCat(name, "_return_value"))); +} + +// Emits a core function call based on the following pseudo-code. +// +// char** parameter_addresses_buffer = +// allocate buffer with a pointer for each parameter to the function +// for each parameter index, i.e. for i = 0, ..., #parameters: +// parameter_addresses_buffer[i] = parameter_addresses[i] +// call function(return_value_buffer, +// parameter_addresses_buffer, +// temps) +// return return_value_buffer -- address of the return value. 
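+//
+// The emitted functions follow the CPU JIT calling convention used elsewhere
+// in this change (see ComputeFunctionType in parallel_cpu_executable.cc),
+// which in C terms is roughly:
+//
+//   void computation(void* retval, const void* run_options,
+//                    const void** params, void** temps,
+//                    uint64* prof_counters);
+//
+// matching the argument indices used by GetResultArgument (0),
+// GetExecutableRunOptionsArgument (1), the params argument (2),
+// GetTempBuffersArgument (3), and GetProfileCountersArgument (4).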
+void IrEmitter::EmitArrayFunctionCallInto( + llvm::Function* function, + tensorflow::gtl::ArraySlice parameter_addresses, + llvm::Value* return_value_buffer, tensorflow::StringPiece name) { + llvm::Value* parameter_addresses_buffer = + llvm_ir::EmitAllocaAtFunctionEntryWithCount( + ir_builder_.getInt8PtrTy(), + ir_builder_.getInt32(parameter_addresses.size()), + tensorflow::strings::StrCat(name, "_parameter_addresses"), + &ir_builder_); + for (int i = 0; i < parameter_addresses.size(); ++i) { + llvm::Value* parameter_as_i8ptr = ir_builder_.CreateBitCast( + parameter_addresses[i], ir_builder_.getInt8PtrTy(), + llvm_ir::AsStringRef(tensorflow::strings::StrCat(name, "_parameter_", i, + "_address_as_i8ptr"))); + llvm::Value* slot_in_param_adresses = ir_builder_.CreateInBoundsGEP( + parameter_addresses_buffer, {ir_builder_.getInt32(i)}); + ir_builder_.CreateStore(parameter_as_i8ptr, slot_in_param_adresses); + } + + const auto to_int8_ptr = [this](llvm::Value* ptr) { + return ir_builder_.CreatePointerCast(ptr, ir_builder_.getInt8PtrTy()); + }; + std::vector arguments{ + to_int8_ptr(return_value_buffer), + to_int8_ptr(GetExecutableRunOptionsArgument()), + parameter_addresses_buffer, GetTempBuffersArgument()}; + if (auto* profile_counters = GetProfileCountersArgument()) { + arguments.push_back(profile_counters); + } + ir_builder_.CreateCall(function, arguments); +} + +llvm::Value* IrEmitter::EmitArrayFunctionCall( + llvm::Function* function, const Shape& return_shape, int64 element_count, + tensorflow::gtl::ArraySlice parameter_addresses, + tensorflow::StringPiece name) { + llvm::Value* elements = + llvm::ConstantInt::get(ir_builder_.getInt64Ty(), element_count); + PrimitiveType return_type = return_shape.element_type(); + llvm::Value* return_value_buffer = + llvm_ir::EmitAllocaAtFunctionEntryWithCount( + llvm_ir::PrimitiveTypeToIrType(return_type, &ir_builder_), elements, + tensorflow::strings::StrCat(name, "_return_value_address"), + &ir_builder_, MinimumAlignmentForPrimitiveType(return_type)); + EmitArrayFunctionCallInto(function, parameter_addresses, return_value_buffer, + name); + return return_value_buffer; +} + +StatusOr IrEmitter::EmitTargetAddressForOp( + const HloInstruction* op) { + const Shape& target_shape = op->shape(); + if (op == op->parent()->root_instruction()) { + // For the root node, we write directly to the output buffer of the + // function. + llvm::Argument* retval = GetResultArgument(); + if (!ShapeUtil::HasZeroElements(target_shape)) { + llvm::AttrBuilder attr_builder; + attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape)); + attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape)); + retval->addAttr(llvm::AttributeSet::get( + retval->getContext(), retval->getArgNo() + 1, attr_builder)); + } + return ir_builder_.CreateBitCast(retval, + IrShapeType(target_shape)->getPointerTo()); + } + + // For other nodes, we need the temporary buffer allocated for this node to + // write the result into. + TF_ASSIGN_OR_RETURN(const BufferAllocation* allocation, + assignment_.GetUniqueTopLevelAllocation(op)); + return EmitTempBufferPointer(allocation->index(), target_shape); +} + +Status IrEmitter::EmitTargetElementLoop( + HloInstruction* target_op, + const llvm_ir::ElementGenerator& element_generator) { + VLOG(2) << "EmitTargetElementLoop: " << target_op->ToString(); + + // target_address will hold the address of the target buffer we will write the + // result of the computation into. 
+ const Shape& target_shape = target_op->shape(); + TF_ASSIGN_OR_RETURN(llvm::Value * target_address, + EmitTargetAddressForOp(target_op)); + VLOG(2) << " target address: " << llvm_ir::DumpToString(*target_address); + llvm_ir::IrArray target_array(target_address, target_shape); + AddAliasingInformationToIrArray(*target_op, &target_array); + + TF_RETURN_IF_ERROR( + llvm_ir::LoopEmitter(element_generator, target_array, &ir_builder_) + .EmitLoop()); + emitted_value_[target_op] = target_address; + return Status::OK(); +} + +Status IrEmitter::EmitMemcpy(const HloInstruction& source, + const HloInstruction& destination) { + llvm::Value* source_value = GetEmittedValueFor(&source); + llvm::Value* destination_value = GetEmittedValueFor(&destination); + int64 source_size = ByteSizeOf(source.shape()); + ir_builder_.CreateMemCpy(destination_value, source_value, source_size, 1); + return Status::OK(); +} + +Status IrEmitter::ElementTypesSameAndSupported( + const HloInstruction& instruction, + tensorflow::gtl::ArraySlice operands, + tensorflow::gtl::ArraySlice supported_types) { + for (auto operand : operands) { + TF_RET_CHECK( + ShapeUtil::SameElementType(operands[0]->shape(), operand->shape())); + } + + TF_RET_CHECK(!operands.empty()); + PrimitiveType primitive_type = operands[0]->shape().element_type(); + if (std::find(supported_types.begin(), supported_types.end(), + primitive_type) == supported_types.end()) { + return Unimplemented("unsupported operand type %s in op %s", + PrimitiveType_Name(primitive_type).c_str(), + HloOpcodeString(instruction.opcode()).c_str()); + } + return Status::OK(); +} + +Status IrEmitter::DefaultAction(HloInstruction* hlo) { + ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator; + for (const HloInstruction* operand : hlo->operands()) { + operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) { + return GetIrArrayForOp(operand).EmitReadArrayElement(index, &ir_builder_); + }; + } + CpuElementalIrEmitter elemental_emitter(hlo_module_config_, &ir_builder_, + module_); + return EmitTargetElementLoop( + hlo, elemental_emitter.MakeElementGenerator(hlo, operand_to_generator)); +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h new file mode 100644 index 0000000000..06415c735d --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -0,0 +1,402 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMITTER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMITTER_H_ + +#include +#include +#include +#include +#include + +#include "external/llvm/include/llvm/ADT/Triple.h" +#include "external/llvm/include/llvm/IR/Function.h" +#include "external/llvm/include/llvm/IR/IRBuilder.h" +#include "external/llvm/include/llvm/IR/Module.h" +#include "external/llvm/include/llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" +#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" +#include "tensorflow/compiler/xla/service/name_uniquer.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace cpu { + +// This class is the top-level API for the XLA HLO --> LLVM IR compiler. It +// implements the DfsHloVisitor interface and emits HLO computations as LLVM IR +// functions. +class IrEmitter : public DfsHloVisitorWithDefault { + public: + // Create a new LLVM IR emitter. + // + // hlo_module: the HLO module we are emitting IR for. + // assignment: a BufferAssignment from which we know which temporary buffers + // are used by the HLO nodes. + // llvm_module: the LLVM module to emit IR into. + // hlo_to_profile_idx: the mapping from HLO to its index in the profiling + // array. + IrEmitter(const HloModule& hlo_module, const HloModuleConfig& module_config, + const BufferAssignment& assignment, llvm::Module* llvm_module, + const std::unordered_map* + hlo_to_profile_idx); + ~IrEmitter() override; + + // Emit and return the given HLO computation as an LLVM IR + // function. function_name_prefix is the desired name of the function. If the + // name is not unique among already emitted functions then a suffix is + // appended to make the name unique. is_entry_computation indicates that this + // is the entry computation of the HLO module. If 'instruction_order' is given + // then the HLO instructions are emitted in the given order. In this case, + // 'instruction_order' must be a topological sort of the set of nodes + // accessible from the root of the computation. + StatusOr EmitComputation( + HloComputation* computation, const string& function_name_prefix, + bool is_entry_computation, + std::vector* instruction_order = nullptr); + + protected: + // + // The following methods implement the DfsHloVisitor interface. + // + // Default action which emits code for most operations. Operations which are + // special in some way are handled explicitly in HandleFoo methods. 
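+ // (As implemented in ir_emitter.cc above, the default action wraps each
+ // operand in an element generator and lowers the instruction to a loop nest
+ // over the output elements via CpuElementalIrEmitter and
+ // EmitTargetElementLoop.)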
+ Status DefaultAction(HloInstruction* hlo_instruction) override; + + Status HandleBitcast(HloInstruction* bitcast) override; + Status HandleConstant(HloInstruction* constant, + const Literal& literal) override; + Status HandleCopy(HloInstruction* copy, HloInstruction* operand) override; + Status HandleGetTupleElement(HloInstruction* get_tuple_element, + HloInstruction* operand) override; + Status HandleSelect(HloInstruction* select, HloInstruction* pred, + HloInstruction* on_true, + HloInstruction* on_false) override; + Status HandleDot(HloInstruction* dot, HloInstruction* lhs, + HloInstruction* rhs) override; + Status HandleConvolution(HloInstruction* convolution, HloInstruction* lhs, + HloInstruction* rhs, const Window& window) override; + Status HandleCrossReplicaSum(HloInstruction* crs) override; + Status HandleInfeed(HloInstruction* infeed) override; + Status HandleSort(HloInstruction* sort, HloInstruction* operand) override; + Status HandleParameter(HloInstruction* parameter) override; + Status HandleReduce(HloInstruction* reduce, HloInstruction* arg, + HloInstruction* init_value, + tensorflow::gtl::ArraySlice dimensions, + HloComputation* function) override; + Status HandleReduceWindow(HloInstruction* reduce_window, + HloInstruction* operand, const Window& window, + HloComputation* function) override; + Status HandleSelectAndScatter(HloInstruction* instruction) override; + Status HandleSend(HloInstruction* send) override; + Status HandleRecv(HloInstruction* recv) override; + Status HandlePad(HloInstruction* pad) override; + Status HandleTuple( + HloInstruction* tuple, + tensorflow::gtl::ArraySlice operands) override; + Status HandleMap( + HloInstruction* map, + tensorflow::gtl::ArraySlice operands, + HloComputation* function, + tensorflow::gtl::ArraySlice static_operands) override; + Status HandleFusion(HloInstruction* fusion) override; + Status HandleCall(HloInstruction* call, + tensorflow::gtl::ArraySlice operands, + HloComputation* computation) override; + Status HandleCustomCall(HloInstruction* custom_call, + tensorflow::gtl::ArraySlice operands, + tensorflow::StringPiece custom_call_target) override; + Status HandleWhile(HloInstruction* xla_while, HloInstruction* init, + HloComputation* condition, HloComputation* body) override; + Status FinishVisit(HloInstruction* root) override; + + Status Preprocess(HloInstruction* hlo) override; + Status Postprocess(HloInstruction* visited) override; + + private: + // Private helper to initialize an IR function for the computation. + void InitializeIrFunction(const string& function_name, + bool is_entry_computation); + + // Convenience function to generate a GEP into the profile counter parameter + // which would correspond to the index for a given HLO. + llvm::Value* GetProfileCounterFor(const HloInstruction* hlo); + + // Convenience function to get the IR Value emitted previously for the given + // hlo. Make sure to call it only when you're certain a value *was* emitted - + // if not found, this will log a fatal error. + llvm::Value* GetEmittedValueFor(const HloInstruction* hlo); + + // Convenience function to get an IrArray representing the given hlo. + llvm_ir::IrArray GetIrArrayForOp(const HloInstruction* hlo); + + // Augments IrArray with aliasing information. + void AddAliasingInformationToIrArray(const HloInstruction& hlo, + llvm_ir::IrArray* array) { + alias_analysis_.AddAliasingInformationToIrArray(hlo, array); + } + + // Convenience function to get the IR type matching the given shape. 
+ llvm::Type* IrShapeType(const Shape& shape); + + // Get the llvm::Value* that represents the "retval" argument of the + // computation function being emitted by this emitter. + llvm::Argument* GetResultArgument(); + + // Get the llvm::Value* that represents the "prof_counters" argument of the + // computation function being emitted by this emitter. + llvm::Argument* GetProfileCountersArgument(); + + // Get the xla::ExecutableRunOptions that represents the "run_options" + // argument of the computation function being emitted by this emitter. + llvm::Value* GetExecutableRunOptionsArgument(); + + // Get the llvm::Value* that represents the "temps" argument of the + // computation function being emitted by this emitter. + llvm::Value* GetTempBuffersArgument(); + + // Emits code that computes the address of the given temporary buffer to the + // function. target_shape is the shape of this temporary buffer. + // The returned Value's type is a pointer to element_type. + llvm::Value* EmitTempBufferPointer(BufferAllocation::Index temp_buf_index, + const Shape& target_shape); + + // Emits a function into the current module. This can be used for + // computations embedded inside other computations, such as the + // function that a map operation applies. + StatusOr EmitFunction( + HloComputation* function, // The function to emit. + tensorflow::StringPiece + function_name_suffix); // Used for LLVM IR register names. + + // Methods that emit a function call. + // Parameters: + // function - The LLVM function to call. + // return_shape - The return shape of the HLO computation that was used to + // make the function. Not the same as the return type of the function + // in LLVM, since we use output parameters for the return type. + // element_count - number of elements to return (array form only). + // parameter_addresses - pointers to be passed to the function as + // parameters. + // name - used for LLVM IR register names. + + // Emits a function call, returning a scalar, often an element of a larger + // array. Returns a Value for the scalar element returned by the function. + llvm::Value* EmitElementFunctionCall( + llvm::Function* function, const Shape& return_shape, + tensorflow::gtl::ArraySlice parameter_addresses, + tensorflow::StringPiece name); + + // Array function call emitter. Stores the function's result into a supplied + // buffer. + // Parameters: + // function - The LLVM function to call. + // parameter_addresses - pointers to be passed to the function as + // parameters. + // return_value - pointer to a buffer where the call result is stored. + + void EmitArrayFunctionCallInto( + llvm::Function* function, + tensorflow::gtl::ArraySlice parameter_addresses, + llvm::Value* return_value, tensorflow::StringPiece name); + + // Array function call emitter. Returns a Value for the function's return + // value buffer address. The return value buffer is alloca'ed by this + // function. + llvm::Value* EmitArrayFunctionCall( + llvm::Function* function, const Shape& return_shape, int64 element_count, + tensorflow::gtl::ArraySlice parameter_addresses, + tensorflow::StringPiece name); + + // Verifies that the element types of all of the given operand instructions + // match and are of one of the given supported types. + Status ElementTypesSameAndSupported( + const HloInstruction& instruction, + tensorflow::gtl::ArraySlice operands, + tensorflow::gtl::ArraySlice supported_types); + + // Emit IR to perform a computation for every element in the given target op. 
+ // This produces a series of nested loops (one for each dimension of the op's + // shape). The body of the inner-most loop is provided by the body_emitter + // function. + // + // TODO(jingyue): target_op should be a `const HloInstruction*`. + Status EmitTargetElementLoop( + HloInstruction* target_op, + const llvm_ir::ElementGenerator& element_generator); + + // Emits a memcpy from the source instruction's result value to the + // destination's. Both source and destination must have an entry in the + // emitted_value_ table. + Status EmitMemcpy(const HloInstruction& source, + const HloInstruction& destination); + + // Emit IR to compute the target address of the buffer for the given op. + // The returned Value is a pointer to a IR type that represents the op's + // element type. + StatusOr EmitTargetAddressForOp(const HloInstruction* op); + + // Structurizes "array_elements" into an MD array that represents "shape". + // This is a recursive function, and "dimension_index" indicates the index of + // the current dimension that the function is considering (0 means the + // most-minor dimension). + llvm::Constant* CreateInitializerForConstantArray( + const std::vector& array_elements, const Shape& shape, + int64 dimension_index); + + // Name of the computation entry function. This function serves as the + // top-level "main" of the computation and will be invoked by the JIT. + string entry_function_name_; + + // Assignment of the temporary buffers needed by the computation and their + // shape information. + const BufferAssignment& assignment_; + + // The LLVM module into which IR will be emitted. + llvm::Module* module_; + + // The target architecture. + llvm::Triple::ArchType arch_type_; + + // Used to produce unique names for generated functions. + NameUniquer name_uniquer_; + + // Map containing all previously emitted computations. + std::map emitted_functions_; + + // Map containing all previously emitted thread-local temporary buffers. + std::map, + llvm::AllocaInst*> + thread_local_buffers_; + + // The following fields track the IR emission state. According to LLVM memory + // management rules, their memory is owned by the module. + llvm::Function* compute_function_; + llvm::IRBuilder<> ir_builder_; + + // Maps HLOs to their index into the profile counter array. + const std::unordered_map* hlo_to_profile_idx_; + + // Maps HLOs to Values emitted for them. + std::unordered_map emitted_value_; + + llvm_ir::AliasAnalysis alias_analysis_; + + // This struct contains all the state needed to emit instructions for + // profiling a computation. + class ProfilingState { + public: + ProfilingState() + : is_entry_computation_(false), + use_rdtscp_(false), + prof_counters_(nullptr) {} + ProfilingState(bool is_entry_computation, bool use_rdtscp, + llvm::Argument* prof_counters) + : is_entry_computation_(is_entry_computation), + use_rdtscp_(use_rdtscp), + prof_counters_(prof_counters) {} + + // Record the cycle counter before an HLO executes. + void RecordCycleStart(llvm::IRBuilder<>* ir_builder, HloInstruction* hlo); + // Record the number of cycles it took for an HLO to execute. + void RecordCycleDelta(llvm::IRBuilder<>* ir_builder, HloInstruction* hlo, + llvm::Value* prof_counter); + // Record the number of cycles it took for the entire computation to + // execute. + void RecordCompleteComputation(llvm::IRBuilder<>* ir_builder, + llvm::Value* prof_counter); + + // Convenience function to generate a call to an intrinsic which reads the + // CPU cycle counter. 
+ llvm::Value* ReadCycleCounter(llvm::IRBuilder<>* ir_builder); + + // Store the cycle counter delta to the per-HLO profile counter. + void UpdateProfileCounter(llvm::IRBuilder<>* ir_builder, + llvm::Value* prof_counter, llvm::Value* cycle_end, + llvm::Value* cycle_start); + + private: + // Is this IrEmitter for a top-level computation? + bool is_entry_computation_; + + // Should we use the x86-specific rdtscp or the generic readcyclecounter + // intrinsic? + bool use_rdtscp_; + + // The argument which corresponds to the profile counter buffer. + llvm::Argument* prof_counters_; + + // The first read cycle counter in the program. + llvm::Value* first_read_cycle_start_ = nullptr; + + // The last read cycle counter in the program. + llvm::Value* last_read_cycle_end_ = nullptr; + + // An alloca used to hold the output of the aux value returned by the rdtscp + // intrinsic. + llvm::Value* aux_i8ptr_ = nullptr; + + // Maps HLOs to the value the cycle counter contained right before the HLO + // began to execute. + std::unordered_map cycle_starts_; + }; + + ProfilingState profiling_state_; + + // Given a load instruction and a shape or buffer size, annotate the load's + // result with the alignment required by the shape or size. + void AttachAlignmentMetadataForLoad(llvm::LoadInst* load, const Shape& shape); + void AttachAlignmentMetadataForLoad(llvm::LoadInst* load, int64 buffer_size); + + // Given a load instruction and a shape or buffer size, annotate the load's + // result with the dereferenceable bytes required by the shape / buffer size. + void AttachDereferenceableMetadataForLoad(llvm::LoadInst* load, + const Shape& shape); + void AttachDereferenceableMetadataForLoad(llvm::LoadInst* load, + int64 buffer_size); + + // Calculate the alignment of a buffer allocated for a given shape. + int MinimumAlignmentForShape(const Shape& shape); + + // Calculate the alignment of a buffer allocated for a given primitive type. + int MinimumAlignmentForPrimitiveType(PrimitiveType primitive_type); + + // Calculate the alignment of a buffer with a particular size. + int MinimumAlignmentForBufferSize(int64 buffer_size); + + // Returns the number of bytes within the shape. + int64 ByteSizeOf(const Shape& shape) const; + + const HloModuleConfig& hlo_module_config_; + + TF_DISALLOW_COPY_AND_ASSIGN(IrEmitter); +}; + +} // namespace cpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMITTER_H_ diff --git a/tensorflow/compiler/xla/service/cpu/layout_assignment.cc b/tensorflow/compiler/xla/service/cpu/layout_assignment.cc new file mode 100644 index 0000000000..136a8c9641 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/layout_assignment.cc @@ -0,0 +1,124 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/layout_assignment.h" + +#include + +#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace xla { +namespace cpu { + +Status CpuLayoutAssignment::AddBackendConstraints( + LayoutConstraints* constraints) { + auto row_major_shape = [](const Shape& old_shape) { + Shape new_shape(old_shape); + std::vector dimension_order(new_shape.dimensions_size()); + std::iota(dimension_order.rbegin(), dimension_order.rend(), 0); + *new_shape.mutable_layout() = LayoutUtil::MakeLayout(dimension_order); + return new_shape; + }; + const HloComputation* computation = constraints->computation(); + for (auto& instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kConvolution && + PotentiallyImplementedAsEigenConvolution(*instruction)) { + const HloInstruction* convolution = instruction.get(); + const HloInstruction* lhs_instruction = convolution->operand(0); + const HloInstruction* rhs_instruction = convolution->operand(1); + + // In order to implement `convolution` with Eigen convolution, the layouts + // of the input, filter, and output need to be row-major. + // + // These constraints are not hard constraints. Ideally, we should decide + // which layouts to choose according to some cost model. + Shape output_shape(row_major_shape(convolution->shape())); + Shape input_shape(row_major_shape(lhs_instruction->shape())); + Shape filter_shape(row_major_shape(rhs_instruction->shape())); + + // Set layouts of the instructions' shapes. + TF_RETURN_IF_ERROR( + constraints->SetOperandLayout(input_shape, convolution, 0)); + TF_RETURN_IF_ERROR( + constraints->SetOperandLayout(filter_shape, convolution, 1)); + TF_RETURN_IF_ERROR( + constraints->SetInstructionLayout(output_shape, convolution)); + } else if (PotentiallyImplementedAsEigenDot(*instruction)) { + const HloInstruction* dot = instruction.get(); + const HloInstruction* lhs_instruction = dot->operand(0); + const HloInstruction* rhs_instruction = dot->operand(1); + + // In order to implement `dot` with Eigen dot, the layouts of the lhs, + // rhs, and output need to be row-major. + // + // These constraints are not hard constraints. Ideally, we should decide + // which layouts to choose according to some cost model. + Shape output_shape(row_major_shape(dot->shape())); + Shape lhs_shape(row_major_shape(lhs_instruction->shape())); + Shape rhs_shape(row_major_shape(rhs_instruction->shape())); + + // Set layouts of the instructions' shapes. + TF_RETURN_IF_ERROR(constraints->SetOperandLayout(lhs_shape, dot, 0)); + TF_RETURN_IF_ERROR(constraints->SetOperandLayout(rhs_shape, dot, 1)); + TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(output_shape, dot)); + } else { + for (int64 operand_no = 0; operand_no < instruction->operand_count(); + ++operand_no) { + // Skip operands which already have a constraint. + if (constraints->OperandLayout(instruction.get(), operand_no) != + nullptr) { + continue; + } + // Skip over forwarded operands. + if (constraints->OperandBufferForwarded(instruction.get(), + operand_no)) { + continue; + } + Shape operand_shape( + row_major_shape(instruction->operand(operand_no)->shape())); + TF_RETURN_IF_ERROR(constraints->SetOperandLayout( + operand_shape, instruction.get(), operand_no)); + } + // Skip over the root instruction for the top-level computation. 
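+ // Its layout is already fixed by the ComputationLayout handed to the pass
+ // (see CpuLayoutAssignment below), so no row-major constraint is added for
+ // it here.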
+ if (computation->parent()->entry_computation() == computation && + computation->root_instruction() == instruction.get()) { + continue; + } + // Skip instructions which don't produce array shapes (tuples, opaque, + // etc.). + if (!ShapeUtil::IsArray(instruction->shape())) { + continue; + } + tensorflow::gtl::ArraySlice buffers = + constraints->points_to_analysis() + .GetPointsToSet(instruction.get()) + .element({}); + // Only force the layout if the instruction hasn't been otherwise assigned + // one or has ambiguous aliasing properties. + if (buffers.size() == 1 && + buffers[0]->instruction() == instruction.get() && + constraints->BufferLayout(*buffers[0]) == nullptr) { + Shape output_shape(row_major_shape(instruction->shape())); + TF_RETURN_IF_ERROR( + constraints->SetInstructionLayout(output_shape, instruction.get())); + } + } + } + return tensorflow::Status::OK(); +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/layout_assignment.h b/tensorflow/compiler/xla/service/cpu/layout_assignment.h new file mode 100644 index 0000000000..4fd8d68dd6 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/layout_assignment.h @@ -0,0 +1,41 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_LAYOUT_ASSIGNMENT_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_LAYOUT_ASSIGNMENT_H_ + +#include "tensorflow/compiler/xla/service/computation_layout.h" +#include "tensorflow/compiler/xla/service/layout_assignment.h" +#include "tensorflow/core/lib/core/status.h" + +namespace xla { +namespace cpu { + +// CPU-specific layout assignment pass which preassigns layouts to satisfy +// layout constraints for operands and results of library calls. +class CpuLayoutAssignment : public LayoutAssignment { + public: + explicit CpuLayoutAssignment(ComputationLayout* entry_computation_layout) + : LayoutAssignment(entry_computation_layout) {} + ~CpuLayoutAssignment() override {} + + protected: + Status AddBackendConstraints(LayoutConstraints* constraints) override; +}; + +} // namespace cpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_LAYOUT_ASSIGNMENT_H_ diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc new file mode 100644 index 0000000000..268a36a660 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc @@ -0,0 +1,365 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "external/llvm/include/llvm/ExecutionEngine/Orc/IRCompileLayer.h" +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/logical_buffer.h" +#include "tensorflow/compiler/xla/service/shaped_buffer.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mem.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/types.h" + +namespace se = ::perftools::gputools; + +namespace xla { +namespace cpu { + +ParallelCpuExecutable::ParallelCpuExecutable( + std::unique_ptr jit, + std::unique_ptr assignment, + std::unique_ptr hlo_module, + std::unique_ptr module_config, + std::unique_ptr> function_names, + std::unordered_map hlo_to_profile_idx, + std::unordered_map> + aligned_constants) + : Executable(std::move(hlo_module), std::move(module_config)), + jit_(std::move(jit)), + assignment_(std::move(assignment)), + functions_names_(std::move(function_names)), + hlo_to_profile_idx_(std::move(hlo_to_profile_idx)), + aligned_constants_(std::move(aligned_constants)) {} + +// Type of the computation function we expect in the JIT. +using ComputeFunctionType = void (*)(void*, const void*, const void**, void**, + uint64*); + +// Given a pointer to an output buffer (following the CPU JIT calling +// conventions), mark addresses that are "live". The initial pointer itself is +// trivially live. If the shape of the buffer is a tuple, this analysis looks +// into the tuple's elements and marks them live as well (since tuples keep +// pointers to buffers) and also works recursively. +// address is an in-memory buffer address that contains some runtime XLA object. +// shape is its shape. marked_addresses is the set of live addresses to +// populate. 
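+// For example, for a result of shape (f32[4], f32[8]) the top-level buffer
+// holds two pointers, and the top-level address as well as both pointed-to
+// element buffers end up in marked_addresses.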
+static void MarkLiveAddressesInOutput( + const void* address, const Shape& shape, + std::unordered_set* marked_addresses) { + marked_addresses->insert(address); + const uintptr_t* address_buffer = static_cast(address); + if (ShapeUtil::IsTuple(shape)) { + for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + const uintptr_t* element_address = address_buffer + i; + const void* element = reinterpret_cast(*element_address); + MarkLiveAddressesInOutput( + element, ShapeUtil::GetTupleElementShape(shape, i), marked_addresses); + } + } +} + +StatusOr +ParallelCpuExecutable::ExecuteOnStream( + const ExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice arguments, + HloExecutionProfile* hlo_execution_profile) { + se::Stream* stream = run_options->stream(); + DeviceMemoryAllocator* memory_allocator = run_options->allocator(); + VLOG(3) << "ExecuteOnStream arg size: " << arguments.size(); + if (!arguments.empty()) { + VLOG(3) << "ExecuteOnStream arg[0]: " << arguments.at(0).opaque(); + } + + // Allocate the temporary buffers required for the computation. + se::StreamExecutor* stream_executor = stream->parent(); + int device_ordinal = stream_executor->device_ordinal(); + int64 buffer_count = assignment_->Allocations().size(); + VLOG(3) << "temp buffer count: " << buffer_count; + + std::vector device_allocations; + for (BufferAllocation::Index i = 0; i < buffer_count; ++i) { + auto& allocation = assignment_->GetAllocation(i); + if (allocation.is_entry_computation_parameter()) { + // Buffers do not need to be allocated for parameters. + device_allocations.push_back(se::DeviceMemoryBase(nullptr)); + continue; + } + + if (allocation.is_thread_local()) { + // Buffers do not need to be allocated for thread-local temporaries. + device_allocations.push_back(se::DeviceMemoryBase(nullptr)); + continue; + } + + TF_ASSIGN_OR_RETURN( + se::DeviceMemoryBase device_allocation, + memory_allocator->Allocate(device_ordinal, allocation.size())); + + // TODO(eliben): refactor this into buffer_assignment + if (VLOG_IS_ON(3)) { + VLOG(3) << "ParallelCpuExecutable allocating " << allocation.size() + << " bytes for allocation #" << i << " [" + << device_allocation.opaque() << "]"; + std::vector parts; + for (const LogicalBuffer* buffer : allocation.assigned_buffers()) { + parts.push_back(buffer->ToString()); + } + VLOG(3) << " " << tensorflow::str_util::Join(parts, ", "); + } + + device_allocations.push_back(device_allocation); + // Since the output buffer and all the temporary buffers were written into + // by the JITed code, msan has no way of knowing their memory was + // initialized. Mark them initialized so that msan doesn't flag loads from + // these buffers. + TF_ANNOTATE_MEMORY_IS_INITIALIZED(device_allocation.opaque(), + allocation.size()); + } + + TF_ASSIGN_OR_RETURN(const BufferAllocation* result_allocation, + assignment_->GetUniqueTopLevelOutputAllocation()); + BufferAllocation::Index result_index = result_allocation->index(); + VLOG(3) << "result index: " << result_index; + + // Allocate profiling counters for each hlo instruction that we would like to + // profile. Allocate an additional profile counter for the entire + // computation. + std::vector profile_counters(hlo_to_profile_idx_.size() + 1); + + std::vector buffer_pointers; + for (auto& device_allocation : device_allocations) { + buffer_pointers.push_back(device_allocation.opaque()); + } + + // Resolve functions for all the HLO instructions ahead of time. 
+ std::map functions; + for (auto& entry : *functions_names_) { + tensorflow::mutex_lock lock(jit_mutex_); + HloInstruction* instruction = entry.first; + llvm::JITSymbol sym = jit_->FindSymbol(entry.second); + TF_RET_CHECK(sym); + InsertOrDie(&functions, instruction, + reinterpret_cast(sym.getAddress())); + } + + // Map containing pointers to result buffers for each instruction. + std::map results; + + uint64 start_micros = tensorflow::Env::Default()->NowMicros(); + + std::list pending; + + // Call the function for each HLO instruction in topological order. + for (auto* instruction : + module().entry_computation()->MakeInstructionPostOrder()) { + // Parameters and constants have no functions associated with them. Instead + // just copy the existing buffer into the map containing instruction + // results.. + if (instruction->opcode() == HloOpcode::kParameter) { + InsertOrDie(&results, instruction, + arguments[instruction->parameter_number()].opaque()); + } else if (instruction->opcode() == HloOpcode::kConstant) { + unsigned char* aligned_data = + FindOrDie(aligned_constants_, instruction).get(); + InsertOrDie(&results, instruction, aligned_data); + } else { + TF_RET_CHECK(instruction->opcode() == HloOpcode::kCall); + pending.push_back(instruction); + } + } + + auto* temps_array = buffer_pointers.data(); + auto* profile_counters_array = profile_counters.data(); + auto* thread_pool = CHECK_NOTNULL(run_options->inter_op_thread_pool()); + tensorflow::mutex completion_queue_lock; + tensorflow::condition_variable completion_queue_cv; + std::deque completion_queue; + int64 instructions_in_flight = 0; + while (!pending.empty() || instructions_in_flight > 0) { + auto pending_it = pending.begin(); + while (pending_it != pending.end()) { + HloInstruction* instruction = *pending_it; + // Skip pending instructions whose operands aren't ready. + if (std::any_of(instruction->operands().begin(), + instruction->operands().end(), + [&](HloInstruction* operand) { + return !ContainsKey(results, operand); + })) { + ++pending_it; + continue; + } + + TF_ASSIGN_OR_RETURN( + const BufferAllocation* result_allocation, + assignment_->GetUniqueTopLevelAllocation(instruction)); + + void* result_buffer = buffer_pointers[result_allocation->index()]; + // We cannot use a move-only RAII type like std::unique_ptr because the + // list of operands is allocated on the main thread and transferred to the + // worker via the lambda passed to enqueue_function. In order for the + // lambda to take ownership, we would need to use generalized lambda + // capture which is a feature new to C++14. + auto operand_buffers = new const void*[instruction->operand_count()]; + std::transform(instruction->operands().begin(), + instruction->operands().end(), operand_buffers, + [&results](HloInstruction* operand) { + return FindOrDie(results, operand); + }); + auto function = FindOrDie(functions, instruction); + // The thread pool entry takes ownership of |operand_buffers|. + thread_pool->Schedule([instruction, &completion_queue, + &completion_queue_lock, &completion_queue_cv, + result_buffer, run_options, operand_buffers, + temps_array, profile_counters_array, function] { + function(result_buffer, run_options, operand_buffers, temps_array, + profile_counters_array); + delete[] operand_buffers; + // Push the completed HLO instruction on the queue, the main thread + // will pop it off and potentially launch more work which uses the + // result. 
+ { + tensorflow::mutex_lock l(completion_queue_lock); + completion_queue.push_back(instruction); + completion_queue_cv.notify_all(); + } + }); + + ++instructions_in_flight; + pending_it = pending.erase(pending_it); + } + // Wait for a completed HLO instruction to be present in the queue. We will + // pop it out of the queue and make the result available to its users. + HloInstruction* instruction; + do { + tensorflow::mutex_lock l(completion_queue_lock); + if (completion_queue.empty()) { + completion_queue_cv.wait(l); + } + if (!completion_queue.empty()) { + instruction = completion_queue.front(); + completion_queue.pop_front(); + break; + } + } while (1); + TF_ASSIGN_OR_RETURN(const BufferAllocation* result_allocation, + assignment_->GetUniqueTopLevelAllocation(instruction)); + void* result_buffer = buffer_pointers[result_allocation->index()]; + InsertOrDie(&results, instruction, result_buffer); + --instructions_in_flight; + } + uint64 end_micros = tensorflow::Env::Default()->NowMicros(); + + { + tensorflow::mutex_lock lock(mutex_); + double nanoseconds = (end_micros - start_micros) * 1000.0; + execution_profile_.set_compute_time_ns(std::max(nanoseconds, 1.0)); + // The last profile counter is used for the computation as a whole. + execution_profile_.set_compute_cycle_count(profile_counters.back()); + } + if (hlo_execution_profile != nullptr) { + hlo_execution_profile->set_total_cycles_executed(profile_counters.back()); + + for (auto hlo_prof_idx : hlo_to_profile_idx_) { + const HloInstruction* hlo = hlo_prof_idx.first; + uint64 cycles_taken = profile_counters[hlo_prof_idx.second]; + hlo_execution_profile->AddProfileResult(hlo, cycles_taken); + } + } + + // Mark the buffers that are actually live (used in the output) when the + // computation finishes executing. + std::unordered_set marked_addresses; + MarkLiveAddressesInOutput(device_allocations[result_index].opaque(), + result_shape(), &marked_addresses); + + VLOG(3) << "Live addresses in output marking found " + << marked_addresses.size() << " addresses:\n" + << tensorflow::str_util::Join( + marked_addresses, ", ", [](string* out, const void* address) { + tensorflow::strings::StrAppend( + out, tensorflow::strings::Printf("%p", address)); + }); + + // Computation is done - deallocate temp buffers. Keep those marked + // live because they are referenced by the output of the computation + // and are needed by the service. They will be deallocated by the + // service. 
+ for (auto i = 0; i < device_allocations.size(); ++i) { + auto alloc = device_allocations[i]; + if (marked_addresses.count(alloc.opaque()) == 0 && + alloc.opaque() != nullptr) { + VLOG(3) << "ParallelCpuExecutable deallocating buffer #" << i << " [" + << alloc.opaque() << "]"; + TF_RETURN_IF_ERROR(memory_allocator->Deallocate(device_ordinal, &alloc)); + } + } + + return device_allocations[result_index]; +} + +StatusOr> ParallelCpuExecutable::ExecuteOnStream( + const ExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice arguments, + HloExecutionProfile* hlo_execution_profile) { + return Unimplemented( + "ParallelCpuExecutable not supported yet with LocalService execution"); +} + +Status ParallelCpuExecutable::ExecuteOnStream( + const ExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice arguments, + ShapedBuffer* result_buffer, HloExecutionProfile* hlo_execution_profile) { + return Unimplemented( + "preallocated result buffer not supported with ParallelCpuExecutable"); +} + +StatusOr +ParallelCpuExecutable::ExecuteAsyncOnStream( + const ExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice arguments) { + // TODO(b/30671675): Implement asynchronous execution mode. + return Unimplemented( + "Asynchronous execution on stream is not yet supported on CPU."); +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h new file mode 100644 index 0000000000..51ec9e5a74 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h @@ -0,0 +1,124 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_CPU_EXECUTABLE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_CPU_EXECUTABLE_H_ + +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/executable.h" +#include "tensorflow/compiler/xla/service/hlo_execution_profile.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/shaped_buffer.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/thread_annotations.h" + +namespace xla { +namespace cpu { + +// CPU-targeting parallel implementation of the XLA Executable interface. 
+// +// Wraps a JIT-ed object that can be executed "on device". We JIT for the host +// architecture, so JIT-ed code and host code share the same ABI. +class ParallelCpuExecutable : public Executable { + public: + ParallelCpuExecutable( + std::unique_ptr jit, + std::unique_ptr assignment, + std::unique_ptr hlo_module, + std::unique_ptr module_config, + std::unique_ptr> instruction_functions, + std::unordered_map hlo_to_profile_idx, + std::unordered_map> + aligned_constants); + ~ParallelCpuExecutable() override {} + + StatusOr ExecuteOnStream( + const ExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice + arguments, + HloExecutionProfile* hlo_execution_profile) override; + + StatusOr> ExecuteOnStream( + const ExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice arguments, + HloExecutionProfile* hlo_execution_profile) override; + + Status ExecuteOnStream( + const ExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice arguments, + ShapedBuffer* result_buffer, + HloExecutionProfile* hlo_execution_profile) override; + + StatusOr ExecuteAsyncOnStream( + const ExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice + arguments) override; + + // This should be called after set_ir_module_string. + const string& ir_module_string() const { return ir_module_string_; } + + void set_ir_module_string(const string& ir_module_string) { + ir_module_string_ = ir_module_string; + } + + private: + // The JIT containing compiled modules. + tensorflow::mutex jit_mutex_; + std::unique_ptr jit_ GUARDED_BY(jit_mutex_); + + // Buffer assignment for the buffers we need to allocate. + std::unique_ptr assignment_; + + // The LLVM IR, in string format, of the unoptimized module generated for this + // ParallelCpuExecutable. We save a string instead of an llvm::Module* because + // leaving llvm::Module* in a singleton can cause the heap checker to emit + // false positives. + string ir_module_string_; + + // Map containing the JITted function names for each HLO instruction. + std::unique_ptr> functions_names_; + + // Maps HLOs to their index into the profile counter array. + const std::unordered_map hlo_to_profile_idx_; + + // Map from HLO Constant instructions to a pointer to their literal data. + // The data stored in the protocol buffer might be insufficiently aligned, + // we create a sufficiently aligned copy and store it in this map. + std::unordered_map> + aligned_constants_; + + TF_DISALLOW_COPY_AND_ASSIGN(ParallelCpuExecutable); +}; + +} // namespace cpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_CPU_EXECUTABLE_H_ diff --git a/tensorflow/compiler/xla/service/cpu/runtime_conv2d.cc b/tensorflow/compiler/xla/service/cpu/runtime_conv2d.cc new file mode 100644 index 0000000000..c2f64eb27a --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_conv2d.cc @@ -0,0 +1,43 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/runtime_conv2d.h" + +#define EIGEN_USE_THREADS + +#include "tensorflow/compiler/xla/executable_run_options.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_conv2d_impl.h" +#include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/types.h" + +using tensorflow::int64; + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConvF32( + const void* run_options_ptr, float* out, float* lhs, float* rhs, + int64 input_batch, int64 input_rows, int64 input_cols, int64 input_channels, + int64 kernel_rows, int64 kernel_cols, int64 kernel_channels, + int64 kernel_filters, int64 output_rows, int64 output_cols, + int64 row_stride, int64 col_stride, int64 padding_top, int64 padding_bottom, + int64 padding_left, int64 padding_right, int64 lhs_row_dilation, + int64 lhs_col_dilation, int64 rhs_row_dilation, int64 rhs_col_dilation) { + const xla::ExecutableRunOptions* run_options = + static_cast(run_options_ptr); + tensorflow::xla::EigenConvF32Impl( + *run_options->intra_op_thread_pool(), out, lhs, rhs, input_batch, + input_rows, input_cols, input_channels, kernel_rows, kernel_cols, + kernel_channels, kernel_filters, output_rows, output_cols, row_stride, + col_stride, padding_top, padding_bottom, padding_left, padding_right, + lhs_row_dilation, lhs_col_dilation, rhs_row_dilation, rhs_col_dilation); +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_conv2d.h b/tensorflow/compiler/xla/service/cpu/runtime_conv2d.h new file mode 100644 index 0000000000..05ae094691 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_conv2d.h @@ -0,0 +1,39 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_CONV2D_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_CONV2D_H_ + +#include "tensorflow/core/platform/types.h" + +extern "C" { + +extern void __xla_cpu_runtime_EigenConvF32( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, float* out, + float* lhs, float* rhs, tensorflow::int64 input_batch, + tensorflow::int64 input_rows, tensorflow::int64 input_cols, + tensorflow::int64 input_channels, tensorflow::int64 kernel_rows, + tensorflow::int64 kernel_cols, tensorflow::int64 kernel_channels, + tensorflow::int64 kernel_filters, tensorflow::int64 output_rows, + tensorflow::int64 output_cols, tensorflow::int64 row_stride, + tensorflow::int64 col_stride, tensorflow::int64 padding_top, + tensorflow::int64 padding_bottom, tensorflow::int64 padding_left, + tensorflow::int64 padding_right, tensorflow::int64 lhs_row_dilation, + tensorflow::int64 lhs_col_dilation, tensorflow::int64 rhs_row_dilation, + tensorflow::int64 rhs_col_dilation); + +} // extern "C" + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_CONV2D_H_ diff --git a/tensorflow/compiler/xla/service/cpu/runtime_conv2d_impl.h b/tensorflow/compiler/xla/service/cpu/runtime_conv2d_impl.h new file mode 100644 index 0000000000..02f45fee0f --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_conv2d_impl.h @@ -0,0 +1,87 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_CONV2D_IMPL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_CONV2D_IMPL_H_ + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/kernels/eigen_spatial_convolutions.h" +#include "tensorflow/core/platform/types.h" + +// 'tensorflow' namespace is used so that int64 and other types don't require +// qualification. 
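+//
+// EigenConvF32Impl below lowers the convolution to an im2col-style
+// formulation: extract_image_patches produces one row per output position,
+// which is then contracted against the reshaped kernel. As a rough example
+// (sizes are illustrative only): for input_batch=1, a 5x5 single-channel
+// input, a 3x3 kernel with 2 filters, unit strides and no padding,
+// output_rows = output_cols = 3, so
+//   pre_contract_dims  = {3 * 3 * 1, 3 * 3 * 1}  // 9 patches, 9 values each
+//   kernel_dims        = {3 * 3 * 1, 2}
+//   post_contract_dims = {1, 3, 3, 2}.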
+namespace tensorflow { +namespace xla { + +template +void EigenConvF32Impl(const EigenDevice& device, float* out, float* lhs, + float* rhs, int64 input_batch, int64 input_rows, + int64 input_cols, int64 input_channels, int64 kernel_rows, + int64 kernel_cols, int64 kernel_channels, + int64 kernel_filters, int64 output_rows, + int64 output_cols, int64 row_stride, int64 col_stride, + int64 padding_top, int64 padding_bottom, + int64 padding_left, int64 padding_right, + int64 lhs_row_dilation, int64 lhs_col_dilation, + int64 rhs_row_dilation, int64 rhs_col_dilation) { + const Eigen::TensorMap, + Eigen::Aligned> + input(lhs, input_batch, input_rows, input_cols, input_channels); + + const Eigen::TensorMap, + Eigen::Aligned> + kernel(rhs, kernel_rows, kernel_cols, kernel_channels, kernel_filters); + + Eigen::TensorMap, Eigen::Aligned> + output(out, input_batch, output_rows, output_cols, kernel_filters); + + Eigen::array, 1> contract_dims; + contract_dims[0] = Eigen::IndexPair(1, 0); + + // Molds the output of the patch extraction code into a 2d tensor: + // - the first dimension (dims[0]): the patch values to be multiplied with the + // kernels + // - the second dimension (dims[1]): everything else + Eigen::DSizes pre_contract_dims; + pre_contract_dims[0] = output_cols * output_rows * input_batch; + pre_contract_dims[1] = kernel_channels * kernel_cols * kernel_rows; + + // Molds the output of the contraction into the shape expected by the user: + Eigen::DSizes post_contract_dims; + post_contract_dims[0] = input_batch; + post_contract_dims[1] = output_rows; + post_contract_dims[2] = output_cols; + post_contract_dims[3] = kernel_filters; + + Eigen::DSizes kernel_dims; + kernel_dims[0] = kernel_channels * kernel_cols * kernel_rows; + kernel_dims[1] = kernel_filters; + + // The row and column dimensions must be flipped when passed to Eigen. + output.device(device) = + input + .extract_image_patches(kernel_cols, kernel_rows, col_stride, + row_stride, rhs_col_dilation, rhs_row_dilation, + lhs_col_dilation, lhs_row_dilation, + padding_left, padding_right, padding_top, + padding_bottom, 0.0f) + .reshape(pre_contract_dims) + .contract(kernel.reshape(kernel_dims), contract_dims) + .reshape(post_contract_dims); +} + +} // namespace xla +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_CONV2D_IMPL_H_ diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc new file mode 100644 index 0000000000..677080a862 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc @@ -0,0 +1,81 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h"
+
+#define EIGEN_USE_THREADS
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/compiler/xla/executable_run_options.h"
+#include "tensorflow/core/platform/types.h"
+
+using tensorflow::int32;
+using tensorflow::int64;
+
+namespace {
+
+template <typename T>
+void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m,
+            int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) {
+  const xla::ExecutableRunOptions* run_options =
+      static_cast<const xla::ExecutableRunOptions*>(run_options_ptr);
+
+  int64 lhs_rows = m;
+  int64 lhs_cols = k;
+  if (transpose_lhs) {
+    std::swap(lhs_rows, lhs_cols);
+  }
+
+  int64 rhs_rows = k;
+  int64 rhs_cols = n;
+  if (transpose_rhs) {
+    std::swap(rhs_rows, rhs_cols);
+  }
+
+  const Eigen::TensorMap<Eigen::Tensor<const T, 2>, Eigen::Aligned> A(
+      lhs, lhs_rows, lhs_cols);
+  const Eigen::TensorMap<Eigen::Tensor<const T, 2>, Eigen::Aligned> B(
+      rhs, rhs_rows, rhs_cols);
+  Eigen::TensorMap<Eigen::Tensor<T, 2>, Eigen::Aligned> C(out, m, n);
+
+  typedef typename Eigen::Tensor<T, 2>::DimensionPair DimPair;
+  int lhs_contract_dim = transpose_lhs ? 0 : 1;
+  int rhs_contract_dim = transpose_rhs ? 1 : 0;
+  const Eigen::array<DimPair, 1> dims(
+      DimPair(lhs_contract_dim, rhs_contract_dim));
+
+  // Matrix multiply is a special case of the "contract" operation where
+  // the contraction is performed along dimension 1 of the lhs and dimension
+  // 0 of the rhs.
+  C.device(*run_options->intra_op_thread_pool()) = A.contract(B, dims);
+}
+
+}  // namespace
+
+void __xla_cpu_runtime_EigenMatMulF32(const void* run_options_ptr, float* out,
+                                      float* lhs, float* rhs, int64 m, int64 n,
+                                      int64 k, int32 transpose_lhs,
+                                      int32 transpose_rhs) {
+  MatMul<float>(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs,
+                transpose_rhs);
+}
+
+void __xla_cpu_runtime_EigenMatMulF64(const void* run_options_ptr, double* out,
+                                      double* lhs, double* rhs, int64 m,
+                                      int64 n, int64 k, int32 transpose_lhs,
+                                      int32 transpose_rhs) {
+  MatMul<double>(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs,
+                 transpose_rhs);
+}
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul.h b/tensorflow/compiler/xla/service/cpu/runtime_matmul.h
new file mode 100644
index 0000000000..fdb644651d
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul.h
@@ -0,0 +1,42 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATMUL_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATMUL_H_
+
+#include "tensorflow/core/platform/types.h"
+
+extern "C" {
+
+// Performs a multi-threaded matrix multiplication using Eigen. 'lhs' and 'rhs'
+// are pointers to buffers containing input matrices in column-major
+// order. 'out' is a pointer to a buffer sufficiently large to hold the result
+// of the operation.
Following standard nomenclature: lhs is m x k, +// rhs is k x n, and out is m x n. +extern void __xla_cpu_runtime_EigenMatMulF32( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, float* out, + float* lhs, float* rhs, tensorflow::int64 m, tensorflow::int64 n, + tensorflow::int64 k, tensorflow::int32 transpose_lhs, + tensorflow::int32 transpose_rhs); + +extern void __xla_cpu_runtime_EigenMatMulF64( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, double* out, + double* lhs, double* rhs, tensorflow::int64 m, tensorflow::int64 n, + tensorflow::int64 k, tensorflow::int32 transpose_lhs, + tensorflow::int32 transpose_rhs); + +} // extern "C" + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATMUL_H_ diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.cc new file mode 100644 index 0000000000..d0b0e11ac0 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.cc @@ -0,0 +1,39 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h" + +#include "tensorflow/compiler/xla/service/cpu/runtime_conv2d_impl.h" +#include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/types.h" + +using tensorflow::int64; + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void +__xla_cpu_runtime_EigenSingleThreadedConvF32( + const void* run_options_ptr, float* out, float* lhs, float* rhs, + int64 input_batch, int64 input_rows, int64 input_cols, int64 input_channels, + int64 kernel_rows, int64 kernel_cols, int64 kernel_channels, + int64 kernel_filters, int64 output_rows, int64 output_cols, + int64 row_stride, int64 col_stride, int64 padding_top, int64 padding_bottom, + int64 padding_left, int64 padding_right, int64 lhs_row_dilation, + int64 lhs_col_dilation, int64 rhs_row_dilation, int64 rhs_col_dilation) { + tensorflow::xla::EigenConvF32Impl( + Eigen::DefaultDevice(), out, lhs, rhs, input_batch, input_rows, + input_cols, input_channels, kernel_rows, kernel_cols, kernel_channels, + kernel_filters, output_rows, output_cols, row_stride, col_stride, + padding_top, padding_bottom, padding_left, padding_right, + lhs_row_dilation, lhs_col_dilation, rhs_row_dilation, rhs_col_dilation); +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h new file mode 100644 index 0000000000..8ae1a42149 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h @@ -0,0 +1,39 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_CONV2D_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_CONV2D_H_ + +#include "tensorflow/core/platform/types.h" + +extern "C" { + +extern void __xla_cpu_runtime_EigenSingleThreadedConvF32( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, float* out, + float* lhs, float* rhs, tensorflow::int64 input_batch, + tensorflow::int64 input_rows, tensorflow::int64 input_cols, + tensorflow::int64 input_channels, tensorflow::int64 kernel_rows, + tensorflow::int64 kernel_cols, tensorflow::int64 kernel_channels, + tensorflow::int64 kernel_filters, tensorflow::int64 output_rows, + tensorflow::int64 output_cols, tensorflow::int64 row_stride, + tensorflow::int64 col_stride, tensorflow::int64 padding_top, + tensorflow::int64 padding_bottom, tensorflow::int64 padding_left, + tensorflow::int64 padding_right, tensorflow::int64 lhs_row_dilation, + tensorflow::int64 lhs_col_dilation, tensorflow::int64 rhs_row_dilation, + tensorflow::int64 rhs_col_dilation); + +} // extern "C" + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_CONV2D_H_ diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc new file mode 100644 index 0000000000..384a978873 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc @@ -0,0 +1,73 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/platform/types.h"
+
+using tensorflow::int32;
+using tensorflow::int64;
+
+namespace {
+
+template <typename T>
+void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m,
+            int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) {
+  int64 lhs_rows = m;
+  int64 lhs_cols = k;
+  if (transpose_lhs) {
+    std::swap(lhs_rows, lhs_cols);
+  }
+
+  int64 rhs_rows = k;
+  int64 rhs_cols = n;
+  if (transpose_rhs) {
+    std::swap(rhs_rows, rhs_cols);
+  }
+
+  const Eigen::TensorMap<Eigen::Tensor<const T, 2>, Eigen::Aligned> A(
+      lhs, lhs_rows, lhs_cols);
+  const Eigen::TensorMap<Eigen::Tensor<const T, 2>, Eigen::Aligned> B(
+      rhs, rhs_rows, rhs_cols);
+  Eigen::TensorMap<Eigen::Tensor<T, 2>, Eigen::Aligned> C(out, m, n);
+
+  typedef typename Eigen::Tensor<T, 2>::DimensionPair DimPair;
+  int lhs_contract_dim = transpose_lhs ? 0 : 1;
+  int rhs_contract_dim = transpose_rhs ? 1 : 0;
+  const Eigen::array<DimPair, 1> dims(
+      DimPair(lhs_contract_dim, rhs_contract_dim));
+
+  // Matrix multiply is a special case of the "contract" operation where
+  // the contraction is performed along dimension 1 of the lhs and dimension
+  // 0 of the rhs.
+  C = A.contract(B, dims);
+}
+
+}  // namespace
+
+void __xla_cpu_runtime_EigenSingleThreadedMatMulF32(
+    const void* run_options_ptr, float* out, float* lhs, float* rhs, int64 m,
+    int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) {
+  MatMul<float>(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs,
+                transpose_rhs);
+}
+
+void __xla_cpu_runtime_EigenSingleThreadedMatMulF64(
+    const void* run_options_ptr, double* out, double* lhs, double* rhs, int64 m,
+    int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) {
+  MatMul<double>(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs,
+                 transpose_rhs);
+}
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h
new file mode 100644
index 0000000000..029eb95142
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h
@@ -0,0 +1,42 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_MATMUL_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_MATMUL_H_
+
+#include "tensorflow/core/platform/types.h"
+
+extern "C" {
+
+// Performs a single-threaded matrix multiplication using Eigen. 'lhs' and 'rhs'
+// are pointers to buffers containing input matrices in column-major order.
+// 'out' is a pointer to a buffer sufficiently large to hold the result of the
+// operation. Following standard nomenclature: lhs is m x k, rhs is k x n, and
+// out is m x n.
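+//
+// Example (illustrative only; the buffer contents are hypothetical): computing
+// out = lhs * rhs for a 2x3 lhs and a 3x4 rhs, both column-major, with no
+// transposes. The single-threaded implementation above does not read
+// run_options_ptr, so nullptr is acceptable there:
+//
+//   float lhs[2 * 3] = {/* column-major 2x3 values */};
+//   float rhs[3 * 4] = {/* column-major 3x4 values */};
+//   float out[2 * 4];
+//   __xla_cpu_runtime_EigenSingleThreadedMatMulF32(
+//       /*run_options_ptr=*/nullptr, out, lhs, rhs, /*m=*/2, /*n=*/4, /*k=*/3,
+//       /*transpose_lhs=*/0, /*transpose_rhs=*/0);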
+extern void __xla_cpu_runtime_EigenSingleThreadedMatMulF32( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, float* out, + float* lhs, float* rhs, tensorflow::int64 m, tensorflow::int64 n, + tensorflow::int64 k, tensorflow::int32 transpose_lhs, + tensorflow::int32 transpose_rhs); + +extern void __xla_cpu_runtime_EigenSingleThreadedMatMulF64( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, double* out, + double* lhs, double* rhs, tensorflow::int64 m, tensorflow::int64 n, + tensorflow::int64 k, tensorflow::int32 transpose_lhs, + tensorflow::int32 transpose_rhs); + +} // extern "C" + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_MATMUL_H_ diff --git a/tensorflow/compiler/xla/service/cpu/sample_harness.cc b/tensorflow/compiler/xla/service/cpu/sample_harness.cc new file mode 100644 index 0000000000..c938a03df7 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/sample_harness.cc @@ -0,0 +1,75 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/client/client.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/logging.h" + +int main(int argc, char** argv) { + tensorflow::port::InitMain(argv[0], &argc, &argv); + + xla::LocalClient* client(xla::ClientLibrary::LocalClientOrDie()); + + // Transfer parameters. + std::unique_ptr param0_literal = + xla::LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); + std::unique_ptr param0_data = + client->TransferToServer(*param0_literal).ConsumeValueOrDie(); + + std::unique_ptr param1_literal = + xla::LiteralUtil::CreateR2( + {{3.1f, 4.2f, 7.3f, 9.5f}, {1.1f, 2.2f, 3.3f, 4.4f}}); + std::unique_ptr param1_data = + client->TransferToServer(*param1_literal).ConsumeValueOrDie(); + + // Build computation. + xla::ComputationBuilder builder(client, ""); + auto p0 = builder.Parameter(0, param0_literal->shape(), "param0"); + auto p1 = builder.Parameter(1, param1_literal->shape(), "param1"); + auto add = builder.Add(p1, p0, {0}); + + xla::StatusOr computation_status = builder.Build(); + xla::Computation computation = computation_status.ConsumeValueOrDie(); + + // Execute and transfer result of computation. 
+ xla::ExecutionProfile profile; + xla::StatusOr> result = + client->ExecuteAndTransfer( + computation, + /*arguments=*/{param0_data.get(), param1_data.get()}, + /*shape_with_output_layout=*/nullptr, + /*execution_profile=*/&profile); + std::unique_ptr actual = result.ConsumeValueOrDie(); + + LOG(INFO) << tensorflow::strings::Printf("computation took %lldns", + profile.compute_time_ns()); + LOG(INFO) << xla::LiteralUtil::ToString(*actual); + + return 0; +} diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc new file mode 100644 index 0000000000..7754c556a8 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -0,0 +1,189 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" + +#include +#include +#include +#include +#include + +#include "external/llvm/include/llvm/IR/Mangler.h" +#include "external/llvm/include/llvm/Support/CodeGen.h" +#include "external/llvm/include/llvm/Support/Host.h" +#include "tensorflow/compiler/xla/legacy_flags/llvm_backend_flags.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/cpu/compiler_functor.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_conv2d.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { +namespace cpu { +namespace { + +// Converts a symbol 'name' into the form expected by dlsym(). +std::string CanonicalizeSymbol(const std::string &name) { +#if defined(__APPLE__) + // On Mac OS X, dlsym() expects names not to be prefixed with a leading + // underscore. + if (!name.empty() && name.front() == '_') { + return name.substr(1); + } +#endif + return name; +} + +// A simple SymbolResolver that delegates to the host dynamic linker. 
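+//
+// Known XLA CPU runtime entry points (matmul, conv, infeed, and the vector
+// intrinsics) are resolved to the statically linked functions declared in the
+// runtime_* headers included above; anything else falls through to dlsym() on
+// the host process. For example, a JITed reference to
+// "__xla_cpu_runtime_EigenMatMulF32" resolves to the address of the function
+// linked into this binary, roughly
+//   func_addr = reinterpret_cast<void*>(__xla_cpu_runtime_EigenMatMulF32);
+// while an unknown symbol such as "memcpy" would be looked up via
+//   dlsym(RTLD_DEFAULT, "memcpy").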
+struct SimpleResolver : public llvm::JITSymbolResolver { + llvm::JITSymbol findSymbol(const std::string &name) override { + void *func_addr = nullptr; + + std::string canonical_name = CanonicalizeSymbol(name); + if (canonical_name == runtime::kEigenMatmulF32SymbolName) { + func_addr = reinterpret_cast(__xla_cpu_runtime_EigenMatMulF32); + } else if (canonical_name == + runtime::kEigenSingleThreadedMatmulF32SymbolName) { + func_addr = reinterpret_cast( + __xla_cpu_runtime_EigenSingleThreadedMatMulF32); + } else if (canonical_name == runtime::kEigenConvF32SymbolName) { + func_addr = reinterpret_cast(__xla_cpu_runtime_EigenConvF32); + } else if (canonical_name == + runtime::kEigenSingleThreadedConvF32SymbolName) { + func_addr = reinterpret_cast( + __xla_cpu_runtime_EigenSingleThreadedConvF32); + } else if (canonical_name == + runtime::kAcquireInfeedBufferForDequeueSymbolName) { + func_addr = reinterpret_cast( + __xla_cpu_runtime_AcquireInfeedBufferForDequeue); + } else if (canonical_name == + runtime::kReleaseInfeedBufferAfterDequeueSymbolName) { + func_addr = reinterpret_cast( + __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue); + } else if (canonical_name == runtime::kExpV4F32) { + func_addr = reinterpret_cast(runtime::ExpV4F32); + } else if (canonical_name == runtime::kExpV8F32) { + func_addr = reinterpret_cast(runtime::ExpV8F32); + } else if (canonical_name == runtime::kLogV4F32) { + func_addr = reinterpret_cast(runtime::LogV4F32); + } else if (canonical_name == runtime::kLogV8F32) { + func_addr = reinterpret_cast(runtime::LogV8F32); + } else if (canonical_name == runtime::kTanhV4F32) { + func_addr = reinterpret_cast(runtime::TanhV4F32); + } else if (canonical_name == runtime::kTanhV8F32) { + func_addr = reinterpret_cast(runtime::TanhV8F32); + } else { + func_addr = dlsym(RTLD_DEFAULT, canonical_name.c_str()); + } + + if (func_addr == nullptr) { + return nullptr; + } + llvm::JITEvaluatedSymbol symbol_info(reinterpret_cast(func_addr), + llvm::JITSymbolFlags::None); + return symbol_info; + } + llvm::JITSymbol findSymbolInLogicalDylib(const std::string &name) override { + return nullptr; + } +}; + +llvm::SmallVector DetectMachineAttributes() { + llvm::SmallVector result; + llvm::StringMap host_features; + if (llvm::sys::getHostCPUFeatures(host_features)) { + for (auto &feature : host_features) { + if (feature.second) { + result.push_back(feature.first()); + } + } + } + return result; +} + +CompilerFunctor::VectorIntrinsics GetAvailableIntrinsics() { + CompilerFunctor::VectorIntrinsics intrinsics; + intrinsics.sse_intrinsics = (&runtime::ExpV4F32 != nullptr); + intrinsics.avx_intrinsics = (&runtime::ExpV8F32 != nullptr); + return intrinsics; +} + +} // namespace + +SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions &target_options, + llvm::CodeGenOpt::Level opt_level) + : target_machine_( + CHECK_NOTNULL(llvm::EngineBuilder() + .setTargetOptions(target_options) + .setOptLevel(opt_level) + .selectTarget( + /*TargetTriple=*/llvm::Triple(), /*MArch=*/"", + /*MCPU=*/llvm::sys::getHostCPUName(), + /*MAttrs=*/DetectMachineAttributes()))), + disassembler_(*target_machine_), + data_layout_(target_machine_->createDataLayout()), + compile_layer_(object_layer_, + CompilerFunctor(target_machine_.get(), &disassembler_, + opt_level, GetAvailableIntrinsics())) {} + +SimpleOrcJIT::ModuleHandleT SimpleOrcJIT::AddModule( + std::unique_ptr module) { + // The Orc API adds a whole iterable "set" of modules, so we wrap the module + // in a vector. 
+ std::vector> module_set; + module_set.push_back(std::move(module)); + auto handle = compile_layer_.addModuleSet( + std::move(module_set), MakeUnique(), + MakeUnique()); + module_handles_.push_back(handle); + return handle; +} + +void SimpleOrcJIT::RemoveModule(SimpleOrcJIT::ModuleHandleT handle) { + module_handles_.erase( + std::remove(module_handles_.begin(), module_handles_.end(), handle)); + compile_layer_.removeModuleSet(handle); +} + +llvm::JITSymbol SimpleOrcJIT::FindSymbol(const std::string &name) { + std::string mangled_name; + { + llvm::raw_string_ostream mangled_name_stream(mangled_name); + llvm::Mangler::getNameWithPrefix(mangled_name_stream, name, data_layout_); + } + + // Resolve symbol from last module to first, allowing later redefinitions of + // symbols shadow earlier ones. + for (auto &handle : + llvm::make_range(module_handles_.rbegin(), module_handles_.rend())) { + if (auto symbol = + compile_layer_.findSymbolIn(handle, mangled_name, + /*ExportedSymbolsOnly=*/true)) { + return symbol; + } + } + + return nullptr; +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h new file mode 100644 index 0000000000..9d1c842e0f --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h @@ -0,0 +1,88 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_SIMPLE_ORC_JIT_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_SIMPLE_ORC_JIT_H_ + +#include +#include +#include + +#include "external/llvm/include/llvm/ADT/Triple.h" +#include "external/llvm/include/llvm/ExecutionEngine/Orc/IRCompileLayer.h" +#include "external/llvm/include/llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h" +#include "external/llvm/include/llvm/IR/Module.h" +#include "external/llvm/include/llvm/Target/TargetMachine.h" +#include "tensorflow/compiler/xla/service/cpu/disassembler.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { +namespace cpu { + +// Simplified LLVM JIT based on the new Orc API. +// +// This class wraps Orc's functionality into a single interface that only +// exposes what we need for XLA. +// +// Supports JIT-ing multiple modules but without cross-module linking. +// Implements eager compilation - the module is lowered to binary as soon as +// it's added to the JIT. +class SimpleOrcJIT { + public: + using ObjLayerT = llvm::orc::ObjectLinkingLayer<>; + using CompileLayerT = llvm::orc::IRCompileLayer; + using ModuleHandleT = CompileLayerT::ModuleSetHandleT; + + // Create a new JIT, targeting the host architecture. + // The |target_options| parameter allows customization of certain code + // generation properties of the TargetMachine (whether or not float point math + // can be reassociated, etc.). + // The |opt_level| parameter controls the optimization level of the code + // generator. 
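+  //
+  // Typical usage (an illustrative sketch; module construction is elided and
+  // the symbol name is hypothetical):
+  //   SimpleOrcJIT jit(llvm::TargetOptions(), llvm::CodeGenOpt::Default);
+  //   auto handle = jit.AddModule(std::move(llvm_module));
+  //   llvm::JITSymbol entry = jit.FindSymbol("entry_function_name");
+  //   // ... call through entry.getAddress(), then:
+  //   jit.RemoveModule(handle);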
+ SimpleOrcJIT(const llvm::TargetOptions& target_options, + llvm::CodeGenOpt::Level opt_level); + + // Data layout this JIT was created with. + const llvm::DataLayout& data_layout() const { return data_layout_; } + + // Target triple (host) this JIT was created with. + const llvm::Triple& target_triple() const { + return target_machine_->getTargetTriple(); + } + + // Add a module to the JIT. Returns an opaque handle that can be used to later + // remove this module. + ModuleHandleT AddModule(std::unique_ptr module); + + // Remove a module from the JIT and free the memory associated with it. + void RemoveModule(ModuleHandleT handle); + + // Get the runtime address of the compiled symbol whose name is given. Returns + // nullptr if the symbol cannot be found. + llvm::JITSymbol FindSymbol(const std::string& name); + + private: + std::vector module_handles_; + std::unique_ptr target_machine_; + const Disassembler disassembler_; + const llvm::DataLayout data_layout_; + ObjLayerT object_layer_; + CompileLayerT compile_layer_; +}; + +} // namespace cpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_SIMPLE_ORC_JIT_H_ diff --git a/tensorflow/compiler/xla/service/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu_transfer_manager.cc new file mode 100644 index 0000000000..423ec29fdc --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu_transfer_manager.cc @@ -0,0 +1,108 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu_transfer_manager.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" +#include "tensorflow/compiler/xla/service/cpu/infeed_manager.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace se = ::perftools::gputools; + +namespace xla { + +namespace { + +class CpuInfeedBuffer : public cpu::runtime::InfeedBuffer { + public: + explicit CpuInfeedBuffer(int32 length) + : length_(length), + buffer_(new char[length]), + device_memory_(buffer_, length_) {} + ~CpuInfeedBuffer() override { delete[] buffer_; } + + int32 length() override { return length_; } + void* data() override { return buffer_; } + void Done() override { delete this; } + + se::DeviceMemoryBase* device_memory() { return &device_memory_; } + + private: + int32 length_; + char* buffer_; + se::DeviceMemoryBase device_memory_; +}; + +} // namespace + +CpuTransferManager::CpuTransferManager() + : GenericTransferManager(se::host::kHostPlatformId) {} + +Status CpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor, + const Literal& literal) { + const Shape& shape = literal.shape(); + VLOG(2) << "transferring literal shape to infeed: " + << ShapeUtil::HumanString(shape); + + // TODO(b/31381668) handle tuples. + if (ShapeUtil::IsTuple(shape)) { + return Unimplemented("Infeed with a tuple shape is not supported: %s", + ShapeUtil::HumanString(literal.shape()).c_str()); + } + + cpu::runtime::InfeedManager* infeed_manager = + cpu::runtime::GetInfeedManager(); + + int64 size = GetByteSizeRequirement(shape); + if (size > std::numeric_limits::max()) { + return Unimplemented("Infeed shape is too large: %s needs %lld bytes", + ShapeUtil::HumanString(literal.shape()).c_str(), size); + } + int32 size_32 = static_cast(size); + CpuInfeedBuffer* queued_buffer = new CpuInfeedBuffer(size_32); + TF_RETURN_IF_ERROR(TransferBufferToDevice( + executor, /*size=*/size, /*source=*/LiteralUtil::InternalData(literal), + queued_buffer->device_memory())); + + infeed_manager->EnqueueBuffer(queued_buffer); + + return Status::OK(); +} + +} // namespace xla + +static xla::TransferManager* CreateCpuTransferManager() { + return new xla::CpuTransferManager(); +} + +static bool InitModule() { + xla::TransferManager::RegisterTransferManager(se::host::kHostPlatformId, + &CreateCpuTransferManager); + return true; +} +static bool module_initialized = InitModule(); diff --git a/tensorflow/compiler/xla/service/cpu_transfer_manager.h b/tensorflow/compiler/xla/service/cpu_transfer_manager.h new file mode 100644 index 0000000000..727462252d --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu_transfer_manager.h @@ -0,0 +1,47 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TRANSFER_MANAGER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TRANSFER_MANAGER_H_ + +#include + +#include "tensorflow/compiler/xla/service/generic_transfer_manager.h" +#include "tensorflow/compiler/xla/service/transfer_manager.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// An implementation of the XLA GenericTransferManager that +// handles CPU-specific infeed. +class CpuTransferManager : public GenericTransferManager { + public: + CpuTransferManager(); + ~CpuTransferManager() override {} + + Status TransferLiteralToInfeed(perftools::gputools::StreamExecutor* executor, + const Literal& literal) override; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(CpuTransferManager); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TRANSFER_MANAGER_H_ diff --git a/tensorflow/compiler/xla/service/device_memory_allocator.cc b/tensorflow/compiler/xla/service/device_memory_allocator.cc new file mode 100644 index 0000000000..1bef4e2b8c --- /dev/null +++ b/tensorflow/compiler/xla/service/device_memory_allocator.cc @@ -0,0 +1,77 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
+
+#include
+
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+
+namespace xla {
+
+StreamExecutorMemoryAllocator::StreamExecutorMemoryAllocator(
+    perftools::gputools::Platform* platform,
+    tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
+        stream_executors)
+    : DeviceMemoryAllocator(platform),
+      stream_executors_(stream_executors.begin(), stream_executors.end()) {}
+
+StatusOr<perftools::gputools::DeviceMemoryBase>
+StreamExecutorMemoryAllocator::Allocate(int device_ordinal, uint64 size,
+                                        bool retry_on_failure) {
+  if (size == 0) {
+    return perftools::gputools::DeviceMemoryBase(nullptr, 0);
+  }
+  TF_ASSIGN_OR_RETURN(perftools::gputools::StreamExecutor * stream_executor,
+                      GetStreamExecutor(device_ordinal));
+  return stream_executor->AllocateArray<uint8>(size);
+}
+
+tensorflow::Status StreamExecutorMemoryAllocator::Deallocate(
+    int device_ordinal, perftools::gputools::DeviceMemoryBase* mem) {
+  if (!mem->is_null()) {
+    TF_ASSIGN_OR_RETURN(perftools::gputools::StreamExecutor * stream_executor,
+                        GetStreamExecutor(device_ordinal));
+    // We make a local copy of 'mem' so the original is not zeroed out by the
+    // Deallocate() call below. This gives us a better chance of
+    // catching double-free bugs, since Deallocate silently succeeds for null
+    // values.
+    perftools::gputools::DeviceMemoryBase mem_copy(*mem);
+    stream_executor->Deallocate(&mem_copy);
+  }
+  return tensorflow::Status::OK();
+}
+
+StatusOr<perftools::gputools::StreamExecutor*>
+StreamExecutorMemoryAllocator::GetStreamExecutor(int device_ordinal) {
+  if (device_ordinal < 0) {
+    return InvalidArgument("device ordinal value (%d) must be non-negative",
+                           device_ordinal);
+  }
+  if (device_ordinal >= stream_executors_.size()) {
+    return InvalidArgument(
+        "device ordinal value (%d) >= number of devices (%zu)", device_ordinal,
+        stream_executors_.size());
+  }
+  if (stream_executors_[device_ordinal] == nullptr) {
+    return NotFound("Device %s:%d present but not supported",
+                    platform()->Name().c_str(), device_ordinal);
+  }
+  return stream_executors_[device_ordinal];
+}
+
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/device_memory_allocator.h b/tensorflow/compiler/xla/service/device_memory_allocator.h
new file mode 100644
index 0000000000..461cc818bf
--- /dev/null
+++ b/tensorflow/compiler/xla/service/device_memory_allocator.h
@@ -0,0 +1,84 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DEVICE_MEMORY_ALLOCATOR_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_DEVICE_MEMORY_ALLOCATOR_H_ + +#include + +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// Interface for device memory allocators used within the XLA service. An +// allocator is responsible for allocating memory on all devices of a particular +// platform. +class DeviceMemoryAllocator { + public: + // Parameter platform indicates which platform the allocator allocates memory + // on. Must be non-null. + explicit DeviceMemoryAllocator(const perftools::gputools::Platform* platform) + : platform_(platform) {} + virtual ~DeviceMemoryAllocator() {} + + // 'retry_on_failure': If false, and the first attempt to allocate the memory + // fails, the allocation should return immediately without retrying. + // An example use case is optional scratch spaces where a failure + // has only performance impact. + // Allocate() should return a null pointer for a size-0 allocation. + // Deallocate() must be a no-op for null pointers. + virtual StatusOr Allocate( + int device_ordinal, uint64 size, bool retry_on_failure = true) = 0; + virtual tensorflow::Status Deallocate( + int device_ordinal, perftools::gputools::DeviceMemoryBase* mem) = 0; + + // Return the platform that the allocator allocates memory on. + const perftools::gputools::Platform* platform() const { return platform_; } + + protected: + const perftools::gputools::Platform* platform_; +}; + +// Default memory allocator for a platform which uses +// StreamExecutor::Allocate/Deallocate. +class StreamExecutorMemoryAllocator : public DeviceMemoryAllocator { + public: + StreamExecutorMemoryAllocator( + perftools::gputools::Platform* platform, + tensorflow::gtl::ArraySlice + stream_executors); + + StatusOr Allocate( + int device_ordinal, uint64 size, bool retry_on_failure = true) override; + tensorflow::Status Deallocate( + int device_ordinal, perftools::gputools::DeviceMemoryBase* mem) override; + + private: + StatusOr GetStreamExecutor( + int device_ordinal); + + // A vector indexed by device ordinal of StreamExecutors for each device of + // the allocator's platform type. If an element is nullptr, then the device + // with the respective device ordinal is not supported by XLA. + std::vector stream_executors_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_DEVICE_MEMORY_ALLOCATOR_H_ diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc b/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc new file mode 100644 index 0000000000..5b29686100 --- /dev/null +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc @@ -0,0 +1,78 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" + +#include + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +Status DfsHloVisitor::HandleElementwiseUnary(HloInstruction* hlo, + HloOpcode opcode, + HloInstruction* operand) { + return Unimplemented("DfsHloVisitor::HandleElementwiseUnary: %s", + HloOpcodeString(opcode).c_str()); +} + +Status DfsHloVisitor::HandleElementwiseBinary(HloInstruction* hlo, + HloOpcode opcode, + HloInstruction* lhs, + HloInstruction* rhs) { + return Unimplemented("DfsHloVisitor::HandleElementwiseBinary: %s", + HloOpcodeString(opcode).c_str()); +} + +void DfsHloVisitor::SetVisiting(const HloInstruction& instruction) { + VLOG(3) << "marking HLO " << &instruction << " as visiting: "; + CHECK(NotVisited(instruction)); + visit_state_[&instruction] = VisitState::kVisiting; +} + +void DfsHloVisitor::SetVisited(const HloInstruction& instruction) { + VLOG(3) << "marking HLO " << &instruction << " as visited: "; + CHECK(NotVisited(instruction) || IsVisiting(instruction)); + visit_state_[&instruction] = VisitState::kVisited; +} + +bool DfsHloVisitor::IsVisiting(const HloInstruction& instruction) { + if (visit_state_.count(&instruction) == 0) { + return false; + } + return visit_state_[&instruction] == VisitState::kVisiting; +} + +bool DfsHloVisitor::DidVisit(const HloInstruction& instruction) { + if (visit_state_.count(&instruction) == 0) { + return false; + } + return visit_state_[&instruction] == VisitState::kVisited; +} + +bool DfsHloVisitor::NotVisited(const HloInstruction& instruction) { + return visit_state_.count(&instruction) == 0 || + visit_state_[&instruction] == VisitState::kNotVisited; +} + +Status DfsHloVisitor::Preprocess(HloInstruction* hlo) { return Status::OK(); } + +Status DfsHloVisitor::Postprocess(HloInstruction* visited) { + return Status::OK(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h new file mode 100644 index 0000000000..24fb6dee84 --- /dev/null +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -0,0 +1,289 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_H_ + +#include + +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +class HloComputation; +class HloInstruction; + +// A postorder depth-first HloInstruction visitor. When Handle* is called on an +// instruction, all its operands were already visited. User code can subclass +// this to iterate over an HloInstruction DAG. The Handle* routines have +// operands / data unpacked for ease of use in the visitor subclass. +// +// No instruction will ever be visited twice; however, the root instruction will +// be reported again when the traversal is done via a call to FinishVisit. +// +// A subclass must override at least +// (either HandleElementwiseUnary or all the Handle methods for unary ops) and +// (either HandleElementwiseBinary or all the Handle methods for binary ops)). +// The default Handle methods for (unary, binary) ops call +// (HandleElementwiseUnary, HandleElementwiseBinary). +// The default (HandleElementwiseUnary, HandleElementwiseBinary) return an +// "unimplemented" error status. +// +// Note: this may change to an iterator in the future for flexibility purposes. +// +// TODO(b/26548304): Stop passing in information about the visited +// instruction that is accessible from the instruction object itself. +class DfsHloVisitor { + public: + DfsHloVisitor() + : visit_state_(32) // Start the hash table a bit larger to avoid resizes + {} + virtual ~DfsHloVisitor() {} + + // These routines are self-descriptive, see class comment for usage + // information. 
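  // Illustrative sketch: one way the visit-state helpers declared below
  // (SetVisiting/SetVisited/DidVisit) can drive a post-order walk. The real
  // driver is HloInstruction::Accept, which is not part of this file; the
  // operands() accessor and the per-opcode Handle* dispatch are assumed here.
  //
  //   Status VisitPostOrder(HloInstruction* hlo, DfsHloVisitor* visitor) {
  //     if (visitor->DidVisit(*hlo)) return Status::OK();  // already handled
  //     visitor->SetVisiting(*hlo);              // kNotVisited -> kVisiting
  //     for (HloInstruction* operand : hlo->operands()) {  // operands first
  //       TF_RETURN_IF_ERROR(VisitPostOrder(operand, visitor));
  //     }
  //     TF_RETURN_IF_ERROR(visitor->Preprocess(hlo));
  //     // ... dispatch to the opcode-specific Handle* method for hlo ...
  //     TF_RETURN_IF_ERROR(visitor->Postprocess(hlo));
  //     visitor->SetVisited(*hlo);               // kVisiting -> kVisited
  //     return Status::OK();
  //   }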
+ + virtual Status HandleElementwiseUnary(HloInstruction* hlo, HloOpcode opcode, + HloInstruction* operand); + virtual Status HandleElementwiseBinary(HloInstruction* hlo, HloOpcode opcode, + HloInstruction* lhs, + HloInstruction* rhs); + virtual Status HandleClamp(HloInstruction* clamp, HloInstruction* min, + HloInstruction* arg, HloInstruction* max) = 0; + virtual Status HandleSelect(HloInstruction* select, HloInstruction* pred, + HloInstruction* on_true, + HloInstruction* on_false) = 0; + virtual Status HandleMaximum(HloInstruction* maximum, HloInstruction* lhs, + HloInstruction* rhs) { + return HandleElementwiseBinary(maximum, HloOpcode::kMaximum, lhs, rhs); + } + virtual Status HandleMinimum(HloInstruction* minimum, HloInstruction* lhs, + HloInstruction* rhs) { + return HandleElementwiseBinary(minimum, HloOpcode::kMinimum, lhs, rhs); + } + virtual Status HandleConcatenate( + HloInstruction* concatenate, + tensorflow::gtl::ArraySlice operands) = 0; + virtual Status HandleConvert(HloInstruction* convert, + HloInstruction* operand) { + return HandleElementwiseUnary(convert, HloOpcode::kConvert, operand); + } + virtual Status HandleCopy(HloInstruction* copy, HloInstruction* operand) { + return HandleElementwiseUnary(copy, HloOpcode::kCopy, operand); + } + virtual Status HandleMultiply(HloInstruction* multiply, HloInstruction* lhs, + HloInstruction* rhs) { + return HandleElementwiseBinary(multiply, HloOpcode::kMultiply, lhs, rhs); + } + virtual Status HandleDot(HloInstruction* dot, HloInstruction* lhs, + HloInstruction* rhs) = 0; + virtual Status HandlePower(HloInstruction* power, HloInstruction* lhs, + HloInstruction* rhs) { + return HandleElementwiseBinary(power, HloOpcode::kPower, lhs, rhs); + } + virtual Status HandleConvolution(HloInstruction* convolution, + HloInstruction* lhs, HloInstruction* rhs, + const Window& window) = 0; + virtual Status HandleCrossReplicaSum(HloInstruction* crs) = 0; + virtual Status HandleCompare(HloInstruction* compare, HloOpcode opcode, + HloInstruction* lhs, HloInstruction* rhs) { + return HandleElementwiseBinary(compare, opcode, lhs, rhs); + } + virtual Status HandleAdd(HloInstruction* add, HloInstruction* lhs, + HloInstruction* rhs) { + return HandleElementwiseBinary(add, HloOpcode::kAdd, lhs, rhs); + } + virtual Status HandleDivide(HloInstruction* divide, HloInstruction* lhs, + HloInstruction* rhs) { + return HandleElementwiseBinary(divide, HloOpcode::kDivide, lhs, rhs); + } + virtual Status HandleRemainder(HloInstruction* remainder, HloInstruction* lhs, + HloInstruction* rhs) { + return HandleElementwiseBinary(remainder, HloOpcode::kRemainder, lhs, rhs); + } + virtual Status HandleSubtract(HloInstruction* subtract, HloInstruction* lhs, + HloInstruction* rhs) { + return HandleElementwiseBinary(subtract, HloOpcode::kSubtract, lhs, rhs); + } + virtual Status HandleAbs(HloInstruction* abs, HloInstruction* operand) { + return HandleElementwiseUnary(abs, HloOpcode::kAbs, operand); + } + virtual Status HandleSign(HloInstruction* sign, HloInstruction* operand) { + return HandleElementwiseUnary(sign, HloOpcode::kSign, operand); + } + virtual Status HandleNegate(HloInstruction* negate, HloInstruction* operand) { + return HandleElementwiseUnary(negate, HloOpcode::kNegate, operand); + } + virtual Status HandleExp(HloInstruction* exp, HloInstruction* operand) { + return HandleElementwiseUnary(exp, HloOpcode::kExp, operand); + } + virtual Status HandleFloor(HloInstruction* floor, HloInstruction* operand) { + return HandleElementwiseUnary(floor, HloOpcode::kFloor, 
operand); + } + virtual Status HandleCeil(HloInstruction* ceil, HloInstruction* operand) { + return HandleElementwiseUnary(ceil, HloOpcode::kCeil, operand); + } + virtual Status HandleLog(HloInstruction* log, HloInstruction* operand) { + return HandleElementwiseUnary(log, HloOpcode::kLog, operand); + } + virtual Status HandleTanh(HloInstruction* tanh, HloInstruction* operand) { + return HandleElementwiseUnary(tanh, HloOpcode::kTanh, operand); + } + virtual Status HandleLogicalAnd(HloInstruction* logical_and, + HloInstruction* lhs, HloInstruction* rhs) { + return HandleElementwiseBinary(logical_and, HloOpcode::kLogicalAnd, lhs, + rhs); + } + virtual Status HandleLogicalNot(HloInstruction* logical_not, + HloInstruction* operand) { + return HandleElementwiseUnary(logical_not, HloOpcode::kLogicalNot, operand); + } + virtual Status HandleLogicalOr(HloInstruction* logical_or, + HloInstruction* lhs, HloInstruction* rhs) { + return HandleElementwiseBinary(logical_or, HloOpcode::kLogicalOr, lhs, rhs); + } + + virtual Status HandleInfeed(HloInstruction* infeed) = 0; + virtual Status HandleRng(HloInstruction* random, + RandomDistribution distribution) = 0; + virtual Status HandleReverse(HloInstruction* reverse, + HloInstruction* operand) = 0; + virtual Status HandleSort(HloInstruction* sort, HloInstruction* operand) = 0; + virtual Status HandleConstant(HloInstruction* constant, + const Literal& literal) = 0; + virtual Status HandleGetTupleElement(HloInstruction* get_tuple_element, + HloInstruction* operand) = 0; + virtual Status HandleReduce(HloInstruction* reduce, HloInstruction* arg, + HloInstruction* init_value, + tensorflow::gtl::ArraySlice dimensions, + HloComputation* function) = 0; + virtual Status HandleBitcast(HloInstruction* bitcast) = 0; + virtual Status HandleBroadcast(HloInstruction* broadcast) = 0; + virtual Status HandleReshape(HloInstruction* reshape) = 0; + virtual Status HandleTranspose(HloInstruction* transpose) = 0; + virtual Status HandleParameter(HloInstruction* parameter) = 0; + virtual Status HandleFusion(HloInstruction* fusion) = 0; + virtual Status HandleCall( + HloInstruction* call, + tensorflow::gtl::ArraySlice operands, + HloComputation* computation) = 0; + virtual Status HandleCustomCall( + HloInstruction* custom_call, + tensorflow::gtl::ArraySlice operands, + tensorflow::StringPiece custom_call_target) = 0; + virtual Status HandleSlice(HloInstruction* slice, + HloInstruction* operand) = 0; + virtual Status HandleDynamicSlice( + HloInstruction* slice, + tensorflow::gtl::ArraySlice operands) = 0; + virtual Status HandleDynamicUpdateSlice(HloInstruction* dynamic_update_slice, + HloInstruction* operand, + HloInstruction* update, + HloInstruction* start_indices) = 0; + virtual Status HandleTuple( + HloInstruction* tuple, + tensorflow::gtl::ArraySlice operands) = 0; + virtual Status HandleMap( + HloInstruction* map, + tensorflow::gtl::ArraySlice operands, + HloComputation* function, + tensorflow::gtl::ArraySlice static_operands) = 0; + virtual Status HandleReduceWindow(HloInstruction* reduce_window, + HloInstruction* operand, + const Window& window, + HloComputation* function) = 0; + virtual Status HandleSelectAndScatter(HloInstruction* instruction) = 0; + virtual Status HandleWhile(HloInstruction* xla_while, HloInstruction* init, + HloComputation* condition, + HloComputation* body) = 0; + + virtual Status HandlePad(HloInstruction* pad) = 0; + + virtual Status HandleSend(HloInstruction* send) = 0; + + virtual Status HandleRecv(HloInstruction* recv) = 0; + + // Invoked to 
inform the visitor that the traversal has completed, and that + // the root was "root". + virtual Status FinishVisit(HloInstruction* root) = 0; + + // 3 possible visitation states of HLO instructions. Each instruction's + // state only flows one way: kNotVisited -> kVisiting -> kVisited. + enum VisitState { + kNotVisited, + kVisiting, + kVisited, + }; + + // Sets the visitation state of the given instruction as kVisiting. + // + // Precondition: current state must be kNotVisited. + void SetVisiting(const HloInstruction& instruction); + + // Sets the visitation state of the given instruction as kVisited. + // + // Precondition: current state must be either kNotVisited or kVisiting. + void SetVisited(const HloInstruction& instruction); + + // Returns whether the state of the given instruction is kVisiting. + bool IsVisiting(const HloInstruction& instruction); + + // Returns whether the state of the given instruction is kVisited. + bool DidVisit(const HloInstruction& instruction); + + // Returns whether the state of the given instruction is kNotVisited. + bool NotVisited(const HloInstruction& instruction); + + // This method should be overridden by subclasses that wish to run some + // operation on an op before its Handle* visitor method is called. + // + // For any HLO op, the order of calls is: + // + // Preprocess(op); + // Handle/OpType/(op); + // Postprocess(op); + // + // Overriding methods should call DfsHloVisitor::Preprocess before doing their + // own preprocessing. + virtual Status Preprocess(HloInstruction* hlo); + + // This method should be overridden by subclasses that wish to run some + // operation on an op after its Handle* visitor method is called. See + // Preprocess for more details. + // + // Overriding methods should call DfsHloVisitor::Postprocess after doing their + // own postprocessing. + virtual Status Postprocess(HloInstruction* visited); + + private: + // Tracks the visitation state of each instruction. Any instructions that are + // not found from the map are considered as VisitState::kNotVisited. + tensorflow::gtl::FlatMap visit_state_; + + TF_DISALLOW_COPY_AND_ASSIGN(DfsHloVisitor); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_H_ diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h new file mode 100644 index 0000000000..4808c7a041 --- /dev/null +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -0,0 +1,226 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_ + +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +class HloComputation; +class HloInstruction; + +// DfsHloVisitor with default action based on the HloInstruction being visited. +class DfsHloVisitorWithDefault : public DfsHloVisitor { + public: + DfsHloVisitorWithDefault() {} + ~DfsHloVisitorWithDefault() override {} + + // Default action performed on HloInstruction. + virtual Status DefaultAction(HloInstruction* hlo_instruction) = 0; + + Status HandleElementwiseUnary(HloInstruction* hlo, HloOpcode opcode, + HloInstruction* operand) override { + return DefaultAction(hlo); + } + Status HandleElementwiseBinary(HloInstruction* hlo, HloOpcode opcode, + HloInstruction* lhs, + HloInstruction* rhs) override { + return DefaultAction(hlo); + } + Status HandleClamp(HloInstruction* clamp, HloInstruction* /*min*/, + HloInstruction* /*arg*/, + HloInstruction* /*max*/) override { + return DefaultAction(clamp); + } + Status HandleConcatenate( + HloInstruction* concatenate, + tensorflow::gtl::ArraySlice /*operands*/) override { + return DefaultAction(concatenate); + } + Status HandleConvert(HloInstruction* convert, + HloInstruction* /*operand*/) override { + return DefaultAction(convert); + } + Status HandleCopy(HloInstruction* copy, + HloInstruction* /*operand*/) override { + return DefaultAction(copy); + } + Status HandleSelect(HloInstruction* select, HloInstruction* /*pred*/, + HloInstruction* /*on_true*/, + HloInstruction* /*on_false*/) override { + return DefaultAction(select); + } + Status HandleDot(HloInstruction* dot, HloInstruction* /*lhs*/, + HloInstruction* /*rhs*/) override { + return DefaultAction(dot); + } + Status HandleConvolution(HloInstruction* convolution, HloInstruction* /*lhs*/, + HloInstruction* /*rhs*/, + const Window& /*window*/) override { + return DefaultAction(convolution); + } + Status HandleCrossReplicaSum(HloInstruction* crs) override { + return DefaultAction(crs); + } + Status HandleCompare(HloInstruction* compare, HloOpcode /*opcode*/, + HloInstruction* /*lhs*/, + HloInstruction* /*rhs*/) override { + return DefaultAction(compare); + } + Status HandleRng(HloInstruction* random, + RandomDistribution /*distribution*/) override { + return DefaultAction(random); + } + Status HandleInfeed(HloInstruction* infeed) override { + return DefaultAction(infeed); + } + Status HandleReverse(HloInstruction* reverse, + HloInstruction* /*operand*/) override { + return DefaultAction(reverse); + } + Status HandleSort(HloInstruction* sort, + HloInstruction* /*operand*/) override { + return DefaultAction(sort); + } + Status HandleConstant(HloInstruction* constant, + const Literal& /*literal*/) override { + return DefaultAction(constant); + } + Status HandleGetTupleElement(HloInstruction* get_tuple_element, + HloInstruction* /*operand*/) override { + return DefaultAction(get_tuple_element); + } + Status HandleParameter(HloInstruction* parameter) override 
{ + return DefaultAction(parameter); + } + Status HandleFusion(HloInstruction* fusion) override { + return DefaultAction(fusion); + } + Status HandleCall(HloInstruction* call, + tensorflow::gtl::ArraySlice /*operands*/, + HloComputation* /*computation*/) override { + return DefaultAction(call); + } + Status HandleCustomCall( + HloInstruction* custom_call, + tensorflow::gtl::ArraySlice /*operands*/, + tensorflow::StringPiece /*custom_call_target*/) override { + return DefaultAction(custom_call); + } + Status HandleSlice(HloInstruction* slice, + HloInstruction* /*operand*/) override { + return DefaultAction(slice); + } + Status HandleDynamicSlice( + HloInstruction* slice, + tensorflow::gtl::ArraySlice /*operands*/) override { + return DefaultAction(slice); + } + Status HandleDynamicUpdateSlice(HloInstruction* dynamic_update_slice, + HloInstruction* /*operand*/, + HloInstruction* /*update*/, + HloInstruction* /*start_indices*/) override { + return DefaultAction(dynamic_update_slice); + } + Status HandleTuple( + HloInstruction* tuple, + tensorflow::gtl::ArraySlice /*operands*/) override { + return DefaultAction(tuple); + } + Status HandleMap( + HloInstruction* map, + tensorflow::gtl::ArraySlice /*operands*/, + HloComputation* /*function*/, + tensorflow::gtl::ArraySlice /*static_operands*/) + override { + return DefaultAction(map); + } + Status HandleReduce(HloInstruction* reduce, HloInstruction* /*arg*/, + HloInstruction* /*init_value*/, + tensorflow::gtl::ArraySlice /*dimensions*/, + HloComputation* /*function*/) override { + return DefaultAction(reduce); + } + Status HandleReduceWindow(HloInstruction* reduce_window, + HloInstruction* /*operand*/, + const Window& /*window*/, + HloComputation* /*function*/) override { + return DefaultAction(reduce_window); + } + Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override { + return DefaultAction(select_and_scatter); + } + Status HandleBitcast(HloInstruction* bitcast) override { + return DefaultAction(bitcast); + } + Status HandleBroadcast(HloInstruction* broadcast) override { + return DefaultAction(broadcast); + } + Status HandlePad(HloInstruction* pad) override { return DefaultAction(pad); } + Status HandleReshape(HloInstruction* reshape) override { + return DefaultAction(reshape); + } + Status HandleTranspose(HloInstruction* transpose) override { + return DefaultAction(transpose); + } + Status HandleWhile(HloInstruction* xla_while, HloInstruction* /*init*/, + HloComputation* /*condition*/, + HloComputation* /*body*/) override { + return DefaultAction(xla_while); + } + Status HandleSend(HloInstruction* send) override { + return DefaultAction(send); + } + Status HandleRecv(HloInstruction* recv) override { + return DefaultAction(recv); + } + + // Invoked to inform the visitor that the traversal has completed, and that + // the root was "root". + Status FinishVisit(HloInstruction* /*root*/) override { return Status::OK(); } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(DfsHloVisitorWithDefault); +}; + +// Helper class for Accept(VisitorFunction) which visits instructions in DFS +// order calling the given function at each instruction. 
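// Example usage (illustrative only; assumes an HloInstruction* `root` and the
// Accept(DfsHloVisitor*) entry point declared on HloInstruction, neither of
// which is defined in this header):
//
//   FunctionVisitor visitor([](HloInstruction* hlo) {
//     VLOG(2) << "visited: " << hlo->ToString();
//     return Status::OK();
//   });
//   TF_RETURN_IF_ERROR(root->Accept(&visitor));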
+class FunctionVisitor : public DfsHloVisitorWithDefault { + public: + using VisitorFunction = std::function; + explicit FunctionVisitor(VisitorFunction visitor_func) + : visitor_func_(std::move(visitor_func)) {} + + Status DefaultAction(HloInstruction* hlo_instruction) override { + return visitor_func_(hlo_instruction); + } + + private: + VisitorFunction visitor_func_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_ diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc new file mode 100644 index 0000000000..1a87a0043a --- /dev/null +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -0,0 +1,934 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" + +#include +#include +#include +#include + +// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" +#include "external/llvm/include/llvm/IR/BasicBlock.h" +#include "external/llvm/include/llvm/IR/Instructions.h" +#include "external/llvm/include/llvm/IR/Intrinsics.h" +#include "external/llvm/include/llvm/Transforms/Utils/BasicBlockUtils.h" +#include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +using llvm_ir::IrArray; +using llvm_ir::SetToFirstInsertPoint; + +StatusOr ElementalIrEmitter::EmitUnaryOp( + const HloInstruction* op, llvm::Value* operand_value) const { + if (op->opcode() == HloOpcode::kCopy) { + return operand_value; + } else { + return operand_value->getType()->isIntegerTy() + ? 
EmitIntegerUnaryOp(op, operand_value) + : EmitFloatUnaryOp(op, operand_value); + } +} + +StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( + const HloInstruction* op, llvm::Value* operand_value) const { + switch (op->opcode()) { + case HloOpcode::kConvert: { + PrimitiveType from_type = op->operand(0)->shape().element_type(); + PrimitiveType to_type = op->shape().element_type(); + CHECK(primitive_util::IsIntegralType(from_type)); + if (from_type == to_type) { + return operand_value; + } + if (primitive_util::IsIntegralType(to_type)) { + return ir_builder_->CreateIntCast( + operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, ir_builder_), + primitive_util::IsSignedIntegralType(to_type)); + } + if (primitive_util::IsFloatingPointType(to_type)) { + if (primitive_util::IsSignedIntegralType(from_type)) { + return ir_builder_->CreateSIToFP( + operand_value, + llvm_ir::PrimitiveTypeToIrType(to_type, ir_builder_)); + } + if (primitive_util::IsUnsignedIntegralType(from_type)) { + return ir_builder_->CreateUIToFP( + operand_value, + llvm_ir::PrimitiveTypeToIrType(to_type, ir_builder_)); + } + } + return Unimplemented("conversion from primitive type %s to %s", + PrimitiveType_Name(from_type).c_str(), + PrimitiveType_Name(to_type).c_str()); + } + case HloOpcode::kAbs: { + bool is_signed = + primitive_util::IsSignedIntegralType(op->shape().element_type()); + if (is_signed) { + auto type = llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), + ir_builder_); + auto zero = llvm::ConstantInt::get(type, 0); + auto cmp = ir_builder_->CreateICmpSGE(operand_value, zero); + return ir_builder_->CreateSelect(cmp, operand_value, + ir_builder_->CreateNeg(operand_value)); + } else { + return operand_value; + } + } + case HloOpcode::kSign: { + bool is_signed = + primitive_util::IsSignedIntegralType(op->shape().element_type()); + auto type = llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), + ir_builder_); + auto zero = llvm::ConstantInt::get(type, 0); + auto cmp = ir_builder_->CreateICmpEQ(operand_value, zero); + if (is_signed) { + auto ashr = ir_builder_->CreateAShr(operand_value, + type->getIntegerBitWidth() - 1); + return ir_builder_->CreateSelect(cmp, zero, + ir_builder_->CreateOr(ashr, 1)); + } else { + return ir_builder_->CreateSelect(cmp, zero, + llvm::ConstantInt::get(type, 1)); + } + } + case HloOpcode::kNegate: + return ir_builder_->CreateNeg(operand_value); + case HloOpcode::kLogicalNot: + // It is not sufficient to just call CreateNot() here because a PRED is + // represented as an i8 and the truth value is stored only in the bottom + // bit. 
+ return ir_builder_->CreateZExt( + ir_builder_->CreateNot(ir_builder_->CreateTrunc( + operand_value, ir_builder_->getInt1Ty())), + llvm_ir::PrimitiveTypeToIrType(PRED, ir_builder_)); + default: + return Unimplemented("unary integer op '%s'", + HloOpcodeString(op->opcode()).c_str()); + } +} + +StatusOr ElementalIrEmitter::EmitFloatUnaryOp( + const HloInstruction* op, llvm::Value* operand_value) const { + switch (op->opcode()) { + case HloOpcode::kConvert: { + PrimitiveType from_type = op->operand(0)->shape().element_type(); + PrimitiveType to_type = op->shape().element_type(); + CHECK(primitive_util::IsFloatingPointType(from_type)); + if (from_type == to_type) { + return operand_value; + } + if (primitive_util::IsFloatingPointType(to_type)) { + return ir_builder_->CreateFPCast( + operand_value, + llvm_ir::PrimitiveTypeToIrType(to_type, ir_builder_)); + } + if (primitive_util::IsSignedIntegralType(to_type)) { + return ir_builder_->CreateFPToSI( + operand_value, + llvm_ir::PrimitiveTypeToIrType(to_type, ir_builder_)); + } + if (primitive_util::IsUnsignedIntegralType(to_type)) { + return ir_builder_->CreateFPToUI( + operand_value, + llvm_ir::PrimitiveTypeToIrType(to_type, ir_builder_)); + } + return Unimplemented("unhandled conversion operation: %s => %s", + PrimitiveType_Name(from_type).c_str(), + PrimitiveType_Name(to_type).c_str()); + } + case HloOpcode::kExp: + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {operand_value}, + {operand_value->getType()}, + ir_builder_); + case HloOpcode::kLog: + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::log, {operand_value}, + {operand_value->getType()}, + ir_builder_); + case HloOpcode::kFloor: + return llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::floor, {operand_value}, {operand_value->getType()}, + ir_builder_); + case HloOpcode::kCeil: + return llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::ceil, {operand_value}, {operand_value->getType()}, + ir_builder_); + case HloOpcode::kAbs: + return llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::fabs, {operand_value}, {operand_value->getType()}, + ir_builder_); + case HloOpcode::kSign: { + // TODO(b/32151903): Ensure consistent sign behavior for -0.0 + auto type = operand_value->getType(); + auto zero = llvm::ConstantFP::get(type, 0.0); + auto oeq = ir_builder_->CreateFCmpOEQ(operand_value, zero); + auto olt = ir_builder_->CreateFCmpOLT(operand_value, zero); + return ir_builder_->CreateSelect( + oeq, zero, + ir_builder_->CreateSelect(olt, llvm::ConstantFP::get(type, -1.0), + llvm::ConstantFP::get(type, 1.0))); + } + case HloOpcode::kNegate: + return ir_builder_->CreateFNeg(operand_value); + default: + return Unimplemented("unary floating-point op '%s'", + HloOpcodeString(op->opcode()).c_str()); + } +} + +StatusOr ElementalIrEmitter::EmitBinaryOp( + const HloInstruction* op, llvm::Value* lhs_value, + llvm::Value* rhs_value) const { + return lhs_value->getType()->isIntegerTy() + ? 
EmitIntegerBinaryOp(op, lhs_value, rhs_value, + primitive_util::IsSignedIntegralType( + op->operand(0)->shape().element_type())) + : EmitFloatBinaryOp(op, lhs_value, rhs_value); +} + +StatusOr ElementalIrEmitter::EmitFloatBinaryOp( + const HloInstruction* op, llvm::Value* lhs_value, + llvm::Value* rhs_value) const { + switch (op->opcode()) { + case HloOpcode::kAdd: + return ir_builder_->CreateFAdd(lhs_value, rhs_value); + case HloOpcode::kSubtract: + return ir_builder_->CreateFSub(lhs_value, rhs_value); + case HloOpcode::kMultiply: + return ir_builder_->CreateFMul(lhs_value, rhs_value); + case HloOpcode::kDivide: + return ir_builder_->CreateFDiv(lhs_value, rhs_value); + case HloOpcode::kRemainder: + return ir_builder_->CreateFRem(lhs_value, rhs_value); + + // The 'O' prefix on the LLVM ops means "ordered" compare where comparisons + // with NAN always return false. + case HloOpcode::kEq: + return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, lhs_value, + rhs_value, ir_builder_); + case HloOpcode::kNe: + return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_ONE, lhs_value, + rhs_value, ir_builder_); + case HloOpcode::kLt: + return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLT, lhs_value, + rhs_value, ir_builder_); + case HloOpcode::kGt: + return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGT, lhs_value, + rhs_value, ir_builder_); + case HloOpcode::kLe: + return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLE, lhs_value, + rhs_value, ir_builder_); + case HloOpcode::kGe: + return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGE, lhs_value, + rhs_value, ir_builder_); + + case HloOpcode::kMaximum: + return EmitFloatMax(lhs_value, rhs_value); + case HloOpcode::kMinimum: + return EmitFloatMin(lhs_value, rhs_value); + case HloOpcode::kPower: + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::pow, + {lhs_value, rhs_value}, + {lhs_value->getType()}, ir_builder_); + + default: + return Unimplemented("binary floating point op '%s'", + HloOpcodeString(op->opcode()).c_str()); + } +} + +llvm::Value* ElementalIrEmitter::EmitFloatMax(llvm::Value* lhs_value, + llvm::Value* rhs_value) const { + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::maxnum, + {lhs_value, rhs_value}, + {lhs_value->getType()}, ir_builder_); +} + +llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value, + llvm::Value* rhs_value) const { + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::minnum, + {lhs_value, rhs_value}, + {lhs_value->getType()}, ir_builder_); +} + +StatusOr ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type, + llvm::Value* x) const { + if (prim_type != F32) { + return Unimplemented("inverse erf"); + } + auto getFloat = [&](const float f) { + return llvm::ConstantFP::get(ir_builder_->getFloatTy(), f); + }; + auto multiply_add = [&](tensorflow::gtl::ArraySlice coefficients, + llvm::Value* w) { + llvm::Value* p = getFloat(coefficients.front()); + coefficients.pop_front(); + for (float coefficient : coefficients) { + p = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(p, w), + getFloat(coefficient)); + } + return p; + }; + + // Approximation for inverse error function from + // Giles, M., "Approximating the erfinv function". 
+ // The approximation has the form: + // w = log((1-x)*(1+x)) + // if ( w < 5 ) { + // w = w - 2.5 + // p = sum_{i=1}^n lq[i]*w^i + // } else { + // w = sqrt(w) - 3 + // p = sum_{i=1}^n gq[i]*w^i + // } + // return p*x + llvm::Function* logf_fn = llvm::Intrinsic::getDeclaration( + module_, llvm::Intrinsic::log, {ir_builder_->getFloatTy()}); + + llvm::Value* w = ir_builder_->CreateFNeg(ir_builder_->CreateCall( + logf_fn, + {ir_builder_->CreateFMul(ir_builder_->CreateFSub(getFloat(1.0f), x), + ir_builder_->CreateFAdd(getFloat(1.0f), x))})); + + llvm::Value* p_addr = llvm_ir::EmitAllocaAtFunctionEntry( + ir_builder_->getFloatTy(), "p.addr", ir_builder_); + + llvm_ir::LlvmIfData if_data = + llvm_ir::EmitIfThenElse(ir_builder_->CreateFCmpOLT(w, getFloat(5.0f)), + "w_less_than_five", ir_builder_); + // Handle true BB. + SetToFirstInsertPoint(if_data.true_block, ir_builder_); + { + llvm::Value* lw = ir_builder_->CreateFSub(w, getFloat(2.5f)); + tensorflow::gtl::ArraySlice lq{ + 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, + -4.39150654e-06f, 0.00021858087f, -0.00125372503f, + -0.00417768164f, 0.246640727f, 1.50140941f}; + llvm::Value* p = multiply_add(lq, lw); + ir_builder_->CreateStore(p, p_addr); + } + + // Handle false BB. + SetToFirstInsertPoint(if_data.false_block, ir_builder_); + { + llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration( + module_, llvm::Intrinsic::sqrt, {ir_builder_->getFloatTy()}); + + llvm::Value* gw = ir_builder_->CreateFSub( + ir_builder_->CreateCall(sqrtf_fn, {w}), getFloat(3.0f)); + tensorflow::gtl::ArraySlice gq{ + -0.000200214257f, 0.000100950558f, 0.00134934322f, + -0.00367342844f, 0.00573950773f, -0.0076224613f, + 0.00943887047f, 1.00167406f, 2.83297682f}; + llvm::Value* p = multiply_add(gq, gw); + ir_builder_->CreateStore(p, p_addr); + } + + SetToFirstInsertPoint(if_data.after_block, ir_builder_); + llvm::Value* p = ir_builder_->CreateLoad(p_addr); + return ir_builder_->CreateFMul(p, x); +} + +StatusOr ElementalIrEmitter::EmitErfcInv( + PrimitiveType prim_type, llvm::Value* value) const { + // Compute erfcinv(value) by calculating erfinv(1.0 - value). + auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, ir_builder_); + auto one = llvm::ConstantFP::get(type, 1.0); + return EmitErfInv(prim_type, ir_builder_->CreateFSub(one, value)); +} + +StatusOr ElementalIrEmitter::EmitIntegerBinaryOp( + const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value, + bool is_signed) const { + switch (op->opcode()) { + // TODO(jingyue): add the "nsw" attribute for signed types. + case HloOpcode::kAdd: + return ir_builder_->CreateAdd(lhs_value, rhs_value); + case HloOpcode::kSubtract: + return ir_builder_->CreateSub(lhs_value, rhs_value); + case HloOpcode::kMultiply: + return ir_builder_->CreateMul(lhs_value, rhs_value); + case HloOpcode::kDivide: + return is_signed ? ir_builder_->CreateSDiv(lhs_value, rhs_value) + : ir_builder_->CreateUDiv(lhs_value, rhs_value); + case HloOpcode::kRemainder: + return is_signed ? ir_builder_->CreateSRem(lhs_value, rhs_value) + : ir_builder_->CreateURem(lhs_value, rhs_value); + case HloOpcode::kEq: + return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_EQ, lhs_value, + rhs_value, ir_builder_); + case HloOpcode::kNe: + return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_NE, lhs_value, + rhs_value, ir_builder_); + case HloOpcode::kLt: + return llvm_ir::EmitComparison( + is_signed ? 
llvm::CmpInst::ICMP_SLT : llvm::CmpInst::ICMP_ULT, + lhs_value, rhs_value, ir_builder_); + case HloOpcode::kGt: + return llvm_ir::EmitComparison( + is_signed ? llvm::CmpInst::ICMP_SGT : llvm::CmpInst::ICMP_UGT, + lhs_value, rhs_value, ir_builder_); + case HloOpcode::kLe: + return llvm_ir::EmitComparison( + is_signed ? llvm::CmpInst::ICMP_SLE : llvm::CmpInst::ICMP_ULE, + lhs_value, rhs_value, ir_builder_); + case HloOpcode::kGe: + return llvm_ir::EmitComparison( + is_signed ? llvm::CmpInst::ICMP_SGE : llvm::CmpInst::ICMP_UGE, + lhs_value, rhs_value, ir_builder_); + case HloOpcode::kMinimum: + return ir_builder_->CreateSelect( + ir_builder_->CreateICmp( + is_signed ? llvm::ICmpInst::ICMP_SLE : llvm::ICmpInst::ICMP_ULE, + lhs_value, rhs_value), + lhs_value, rhs_value); + case HloOpcode::kMaximum: + return ir_builder_->CreateSelect( + ir_builder_->CreateICmp( + is_signed ? llvm::ICmpInst::ICMP_SGE : llvm::ICmpInst::ICMP_UGE, + lhs_value, rhs_value), + lhs_value, rhs_value); + case HloOpcode::kLogicalAnd: + return ir_builder_->CreateAnd(lhs_value, rhs_value); + case HloOpcode::kLogicalOr: + return ir_builder_->CreateOr(lhs_value, rhs_value); + default: + return Unimplemented("binary integer op '%s'", + HloOpcodeString(op->opcode()).c_str()); + } +} + +llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex( + const llvm_ir::IrArray::Index& target_index, const HloInstruction& hlo, + int64 operand_no) const { + CHECK(hlo.IsElementwise()) << "HLO " << hlo.ToString() + << " is not elementwise."; + + const Shape& operand_shape = hlo.operand(operand_no)->shape(); + // If the operand is scalar, the source index is always {}. + if (ShapeUtil::IsScalar(operand_shape)) { + return llvm_ir::IrArray::Index(); + } + + // If no implicit broadcast is needed for this operand, returns the target + // index as the source index. + if (ShapeUtil::Compatible(operand_shape, hlo.shape())) { + return target_index; + } + + // If implicit broadcast is needed, the source dimensions that are broadcast + // have index 0. + CHECK_EQ(ShapeUtil::Rank(operand_shape), ShapeUtil::Rank(hlo.shape())); + llvm_ir::IrArray::Index source_index; + for (int64 i = 0; i < ShapeUtil::Rank(hlo.shape()); ++i) { + if (hlo.shape().dimensions(i) == operand_shape.dimensions(i)) { + source_index.push_back(target_index[i]); + } else { + CHECK_EQ(1, operand_shape.dimensions(i)); + source_index.push_back(ir_builder_->getInt64(0)); + } + } + return source_index; +} + +llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator( + const HloInstruction* hlo, + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator) + const { + PrimitiveType param_prim_type = hlo->operand(0)->shape().element_type(); + llvm::Type* param_ir_type = + llvm_ir::PrimitiveTypeToIrType(param_prim_type, ir_builder_); + + // Same values as PCG library + // https://github.com/imneme/pcg-c/blob/master/include/pcg_variants.h + llvm::Value* multiplier = ir_builder_->getInt( + llvm::APInt(128, {0x4385DF649FCCF645, 0x2360ED051FC65DA4})); + llvm::Value* increment = ir_builder_->getInt( + llvm::APInt(128, {0x14057B7EF767814F, 0x5851F42D4C957F2D})); + + auto random_value = [hlo]() { + CHECK(hlo->parent() != nullptr && hlo->parent()->parent() != nullptr); + const HloModule* module = hlo->parent()->parent(); + return module->RandomNew64(); + }; + + // Seed each RNG emitter with a new 64-bit seed from the HloModule. 
If the + // compilation order is deterministic (i.e., RandomNew64 invocation order is + // deterministic), then the order of RNG is deterministic for a given seed and + // hence tests will be deterministic. + // If the user provides a global seed instruction then we only use 64-bits of + // the host's random number generator to seed the 128 bit value with the other + // 64-bits is due to a user specified global seed instruction. + // Create a GlobalVariable to maintain state between invocations. There is a + // bug in NVPTX with GlobalVariable and 128 bit values, so using 2 64-bit + // values. + llvm::GlobalVariable* state_ptr0 = new llvm::GlobalVariable( + /*M=*/*module_, + /*Ty=*/ir_builder_->getInt64Ty(), + /*isConstant=*/false, + /*Linkage=*/llvm::GlobalValue::PrivateLinkage, + /*Initializer=*/ir_builder_->getInt64(random_value()), + /*Name=*/"state_ptr0"); + uint64 graph_seed = hlo_module_config_.seed() != 0 ? hlo_module_config_.seed() + : random_value(); + llvm::GlobalVariable* state_ptr1 = new llvm::GlobalVariable( + /*M=*/*module_, + /*Ty=*/ir_builder_->getInt64Ty(), + /*isConstant=*/false, + /*Linkage=*/llvm::GlobalValue::PrivateLinkage, + /*Initializer=*/ir_builder_->getInt64(graph_seed), + /*Name=*/"state_ptr1"); + + // We want each thread to use its own stream, so we modify the increment per + // thread. We want the increment to remain odd, so we shift the thread id left + // 1 and add it to the increment. + increment = ir_builder_->CreateAdd(increment, + ir_builder_->CreateShl(EmitThreadId(), 1)); + + // PCG-XSL-RR algorithm + // http://www.pcg-random.org/pdf/toms-oneill-pcg-family-v1.02.pdf + // state = multiplier * state + increment + // return uint64_t(state ^ (state >> 64))) >>> (state >> 122) + // where ">>>" is bitwise rotation + auto get_next_i64 = [=]() { + llvm::Value* state0 = ir_builder_->CreateZExtOrTrunc( + ir_builder_->CreateLoad(state_ptr0, "state0"), + ir_builder_->getInt128Ty()); + llvm::Value* state1 = ir_builder_->CreateShl( + ir_builder_->CreateZExtOrTrunc( + ir_builder_->CreateLoad(state_ptr1, "state1"), + ir_builder_->getInt128Ty()), + 64); + llvm::Value* state = ir_builder_->CreateOr(state0, state1); + llvm::Value* updated = ir_builder_->CreateAdd( + ir_builder_->CreateMul(state, multiplier), increment); + ir_builder_->CreateStore( + ir_builder_->CreateTrunc(updated, ir_builder_->getInt64Ty()), + state_ptr0); + ir_builder_->CreateStore( + ir_builder_->CreateTrunc(ir_builder_->CreateLShr(updated, 64), + ir_builder_->getInt64Ty()), + state_ptr1); + + return llvm_ir::CreateRor( + ir_builder_->CreateTrunc( + ir_builder_->CreateXor(state, ir_builder_->CreateLShr(state, 64)), + ir_builder_->getInt64Ty()), + ir_builder_->CreateTrunc(ir_builder_->CreateLShr(state, 122), + ir_builder_->getInt64Ty()), + ir_builder_); + }; + + auto get_next_uniform_float = [=]() { + return ir_builder_->CreateFDiv( + ir_builder_->CreateUIToFP(get_next_i64(), param_ir_type), + llvm::ConstantFP::get(param_ir_type, 0x1p64)); + }; + + return [=](const llvm_ir::IrArray::Index& index) -> StatusOr { + switch (hlo->random_distribution()) { + case RNG_UNIFORM: { + TF_ASSIGN_OR_RETURN(llvm::Value * p, + operand_to_generator.at(hlo->operand(0))(index)); + TF_ASSIGN_OR_RETURN(llvm::Value * q, + operand_to_generator.at(hlo->operand(1))(index)); + if (primitive_util::IsFloatingPointType(param_prim_type)) { + return ir_builder_->CreateFAdd( + ir_builder_->CreateFMul(ir_builder_->CreateFSub(q, p), + get_next_uniform_float()), + p); + } else { + auto r = ir_builder_->CreateSub(q, p); + auto 
leading_zeros = llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::ctlz, {r, ir_builder_->getInt1(1)}, + {param_ir_type}, ir_builder_); + auto in_block = ir_builder_->GetInsertBlock(); + auto body_block = in_block->splitBasicBlock( + ir_builder_->GetInsertPoint(), "rng_body"); + SetToFirstInsertPoint(body_block, ir_builder_); + auto out_block = body_block->splitBasicBlock( + ir_builder_->GetInsertPoint(), "rng_out"); + SetToFirstInsertPoint(body_block, ir_builder_); + auto random = ir_builder_->CreateAnd( + ir_builder_->CreateZExtOrTrunc(get_next_i64(), param_ir_type), + ir_builder_->CreateLShr(llvm::ConstantInt::get(param_ir_type, ~0), + leading_zeros)); + llvm::ReplaceInstWithInst( + body_block->getTerminator(), + llvm::BranchInst::Create(out_block, body_block, + ir_builder_->CreateICmpULE(random, r))); + SetToFirstInsertPoint(out_block, ir_builder_); + return ir_builder_->CreateAdd( + p, ir_builder_->CreateSelect( + ir_builder_->CreateICmpEQ(p, q), + llvm::ConstantInt::get(param_ir_type, 0), random)); + } + } + case RNG_NORMAL: { + TF_ASSIGN_OR_RETURN(llvm::Value * m, + operand_to_generator.at(hlo->operand(0))(index)); + TF_ASSIGN_OR_RETURN(llvm::Value * s, + operand_to_generator.at(hlo->operand(1))(index)); + TF_ASSIGN_OR_RETURN( + llvm::Value * r, + EmitErfcInv(param_prim_type, + ir_builder_->CreateFMul( + llvm::ConstantFP::get(param_ir_type, 2.0), + get_next_uniform_float()))); + return ir_builder_->CreateFAdd(ir_builder_->CreateFMul(r, s), m); + } + case RNG_BERNOULLI: { + TF_ASSIGN_OR_RETURN(llvm::Value * p, + operand_to_generator.at(hlo->operand(0))(index)); + return ir_builder_->CreateZExt( + ir_builder_->CreateFCmpOLT(get_next_uniform_float(), p), + llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), + ir_builder_)); + } + default: + return InvalidArgument( + "unhandled distribution %s", + RandomDistribution_Name(hlo->random_distribution()).c_str()); + } + }; +} + +llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( + const HloInstruction* hlo, + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator) + const { + // TODO(mfdyck): Make capture lists explicit, lest someone forget to cap + // `operand_to_generator` by ref and its many copies fill memory and cause + // much woe and process death. 
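  // Orientation note (an informal restatement, not additional behavior): an
  // llvm_ir::ElementGenerator is essentially a
  //   std::function<StatusOr<llvm::Value*>(const llvm_ir::IrArray::Index&)>
  // so each case below wraps the generators of hlo's operands, looked up in
  // operand_to_generator, into a new generator for one element of hlo's
  // output; it captures operand_to_generator by reference, as the TODO above
  // warns.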
+ switch (hlo->opcode()) { + case HloOpcode::kAbs: + case HloOpcode::kCeil: + case HloOpcode::kConvert: + case HloOpcode::kCopy: + case HloOpcode::kExp: + case HloOpcode::kFloor: + case HloOpcode::kLog: + case HloOpcode::kNegate: + case HloOpcode::kSign: + case HloOpcode::kTanh: + case HloOpcode::kLogicalNot: + return [=, &operand_to_generator]( + const IrArray::Index& index) -> StatusOr { + TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, + operand_to_generator.at(hlo->operand(0))( + ElementwiseSourceIndex(index, *hlo, 0))); + return EmitUnaryOp(hlo, operand_value); + }; + case HloOpcode::kAdd: + case HloOpcode::kDivide: + case HloOpcode::kEq: + case HloOpcode::kGe: + case HloOpcode::kGt: + case HloOpcode::kLe: + case HloOpcode::kLt: + case HloOpcode::kMaximum: + case HloOpcode::kMinimum: + case HloOpcode::kMultiply: + case HloOpcode::kNe: + case HloOpcode::kPower: + case HloOpcode::kRemainder: + case HloOpcode::kSubtract: + case HloOpcode::kLogicalAnd: + case HloOpcode::kLogicalOr: + return [=, &operand_to_generator]( + const IrArray::Index& index) -> StatusOr { + const HloInstruction* lhs = hlo->operand(0); + const HloInstruction* rhs = hlo->operand(1); + TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, + operand_to_generator.at(lhs)( + ElementwiseSourceIndex(index, *hlo, 0))); + TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, + operand_to_generator.at(rhs)( + ElementwiseSourceIndex(index, *hlo, 1))); + return EmitBinaryOp(hlo, lhs_value, rhs_value); + }; + case HloOpcode::kSelect: + return [=, &operand_to_generator]( + const IrArray::Index& index) -> StatusOr { + TF_ASSIGN_OR_RETURN(llvm::Value * pred_value, + operand_to_generator.at(hlo->operand(0))( + ElementwiseSourceIndex(index, *hlo, 0))); + TF_ASSIGN_OR_RETURN(llvm::Value * on_true_value, + operand_to_generator.at(hlo->operand(1))( + ElementwiseSourceIndex(index, *hlo, 1))); + TF_ASSIGN_OR_RETURN(llvm::Value * on_false_value, + operand_to_generator.at(hlo->operand(2))( + ElementwiseSourceIndex(index, *hlo, 2))); + return ir_builder_->CreateSelect( + ir_builder_->CreateTrunc(pred_value, ir_builder_->getInt1Ty()), + on_true_value, on_false_value); + }; + case HloOpcode::kClamp: + return [=, &operand_to_generator]( + const IrArray::Index& index) -> StatusOr { + TF_ASSIGN_OR_RETURN(llvm::Value * min_value, + operand_to_generator.at(hlo->operand(0))( + ElementwiseSourceIndex(index, *hlo, 0))); + TF_ASSIGN_OR_RETURN(llvm::Value * arg_value, + operand_to_generator.at(hlo->operand(1))( + ElementwiseSourceIndex(index, *hlo, 1))); + TF_ASSIGN_OR_RETURN(llvm::Value * max_value, + operand_to_generator.at(hlo->operand(2))( + ElementwiseSourceIndex(index, *hlo, 2))); + return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value)); + }; + case HloOpcode::kConcatenate: + return [=, &operand_to_generator]( + const IrArray::Index target_index) -> StatusOr { + const int64 concat_dim = hlo->dimensions(0); + auto source_index = target_index; + + llvm::PHINode* output = ir_builder_->CreatePHI( + llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), + ir_builder_), + hlo->operands().size()); + llvm::BasicBlock* init_block = ir_builder_->GetInsertBlock(); + auto prior_insert_point = ir_builder_->GetInsertPoint(); + llvm::BasicBlock* exit_block = + init_block->splitBasicBlock(output, "concat_merge"); + + ir_builder_->SetInsertPoint(init_block); + init_block->getTerminator()->eraseFromParent(); + + for (int64 operand_idx = 0; operand_idx < hlo->operand_count(); + ++operand_idx) { + const HloInstruction* operand = hlo->operand(operand_idx); + auto 
true_block = llvm_ir::CreateBasicBlock( + exit_block, tensorflow::strings::StrCat( + "concat_index_from_operand", operand_idx), + ir_builder_); + auto false_block = llvm_ir::CreateBasicBlock( + exit_block, tensorflow::strings::StrCat( + "concat_index_not_from_operand", operand_idx), + ir_builder_); + auto concat_dim_size = + llvm::ConstantInt::get(source_index[concat_dim]->getType(), + operand->shape().dimensions(concat_dim)); + ir_builder_->CreateCondBr( + ir_builder_->CreateICmpULT(source_index[concat_dim], + concat_dim_size), + true_block, false_block); + + // Create the terminator of the true block before calling operand + // generators, because they require non-degenerate basic blocks. + ir_builder_->SetInsertPoint( + llvm::BranchInst::Create(exit_block, /*InsertAtEnd=*/true_block)); + TF_ASSIGN_OR_RETURN(llvm::Value * value, + operand_to_generator.at(operand)(source_index)); + output->addIncoming(value, ir_builder_->GetInsertBlock()); + + // Subtract the size of the concat dimension of the current operand + // from the source index. + ir_builder_->SetInsertPoint(false_block); + source_index[concat_dim] = + ir_builder_->CreateSub(source_index[concat_dim], concat_dim_size); + } + + ir_builder_->CreateUnreachable(); + ir_builder_->SetInsertPoint(exit_block, prior_insert_point); + return output; + }; + case HloOpcode::kReverse: + return [=, &operand_to_generator]( + const IrArray::Index& target_index) -> StatusOr { + const HloInstruction* operand = hlo->operand(0); + auto source_index = target_index; + for (int64 dim : hlo->dimensions()) { + source_index[dim] = ir_builder_->CreateSub( + llvm::ConstantInt::get(target_index[dim]->getType(), + hlo->shape().dimensions(dim) - 1), + target_index[dim]); + } + return operand_to_generator.at(operand)(source_index); + }; + case HloOpcode::kBroadcast: + return [=, &operand_to_generator]( + const IrArray::Index& target_index) -> StatusOr { + // The `dimensions` member of the broadcast instruction maps from + // input dimensions to output dimensions. + const HloInstruction* operand = hlo->operand(0); + int64 rank = ShapeUtil::Rank(operand->shape()); + IrArray::Index source_index(rank); + for (int64 i = 0; i < rank; ++i) { + source_index[i] = target_index[hlo->dimensions(i)]; + } + return operand_to_generator.at(operand)(source_index); + }; + case HloOpcode::kSlice: + return [=, &operand_to_generator]( + const IrArray::Index& index) -> StatusOr { + IrArray::Index sliced_index(index.size()); + for (int i = 0; i < index.size(); ++i) { + sliced_index[i] = ir_builder_->CreateAdd( + index[i], llvm::ConstantInt::get(index[i]->getType(), + hlo->slice_starts(i))); + } + return operand_to_generator.at(hlo->operand(0))(sliced_index); + }; + case HloOpcode::kDynamicSlice: + return [=, &operand_to_generator]( + const IrArray::Index& index) -> StatusOr { + // Emit IR to read dynamic start indices from hlo->operand(1). 
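        // The index arithmetic below keeps every read in bounds by wrapping:
        // input_index = (start_index + offset_index) % dim_size. For example,
        // with dim_size = 4, start_index = 3 and offset_index = 2, the
        // generated IR reads element (3 + 2) % 4 = 1 instead of running past
        // the end of the dimension.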
+ const HloInstruction* input_hlo = hlo->operand(0); + const int64 rank = ShapeUtil::Rank(input_hlo->shape()); + llvm_ir::IrArray::Index slice_start_index(rank); + for (int64 i = 0; i < rank; ++i) { + llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i)); + TF_ASSIGN_OR_RETURN( + llvm::Value * start_index_value, + operand_to_generator.at(hlo->operand(1))(dim_index)); + slice_start_index[i] = start_index_value; + } + + llvm_ir::IrArray::Index input_index(rank); + for (int64 i = 0; i < rank; ++i) { + // Emit IR which computes: + // input_index = (start_index + offset_index) % dim_size + // Security note: this is the code that keeps the indices in-bounds. + llvm::Value* dim_size = llvm::ConstantInt::get( + index[i]->getType(), input_hlo->shape().dimensions(i)); + llvm::Value* start_index = ir_builder_->CreateZExtOrBitCast( + slice_start_index[i], index[i]->getType()); + input_index[i] = ir_builder_->CreateURem( + ir_builder_->CreateAdd(start_index, index[i]), dim_size); + } + return operand_to_generator.at(input_hlo)(input_index); + }; + case HloOpcode::kDynamicUpdateSlice: + return [=, &operand_to_generator]( + const IrArray::Index& index) -> StatusOr { + const HloInstruction* input_hlo = hlo->operand(0); + const HloInstruction* update_hlo = hlo->operand(1); + const HloInstruction* start_hlo = hlo->operand(2); + // Calculate slice start/end indices. + const int64 rank = ShapeUtil::Rank(input_hlo->shape()); + llvm_ir::IrArray::Index slice_start_index(rank); + llvm_ir::IrArray::Index slice_limit_index(rank); + for (int64 i = 0; i < rank; ++i) { + // Emit IR to read dynamic start indices from 'start_hlo'. + llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i)); + TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value, + operand_to_generator.at(start_hlo)(dim_index)); + slice_start_index[i] = ir_builder_->CreateZExtOrBitCast( + start_index_value, index[i]->getType()); + // Emit IR to compute: slice_limit_index = start_index + update_dim + // NOTE: Although 'start_indices' is dynamic and could be + // out-of-range, we do not compute 'slice_limit_index' mod input dim + // size here, because subsequent array index calculations will be + // computed mod input dim size for safety. + llvm::Value* update_dim_size = llvm::ConstantInt::get( + index[i]->getType(), update_hlo->shape().dimensions(i)); + slice_limit_index[i] = + ir_builder_->CreateAdd(slice_start_index[i], update_dim_size); + } + + // Check if 'index' intersects start/end indices. + llvm::Value* slice_intersection = + llvm::ConstantInt::get(ir_builder_->getInt1Ty(), 1); + + for (int64 i = 0; i < rank; ++i) { + // Check that index[i] >= slice_start_index[i]. + slice_intersection = ir_builder_->CreateAnd( + slice_intersection, + ir_builder_->CreateICmpSGE(index[i], slice_start_index[i]), + "slice_intersection"); + + // Check that index[i] < slice_limit_index[i]. + slice_intersection = ir_builder_->CreateAnd( + slice_intersection, + ir_builder_->CreateICmpSLT(index[i], slice_limit_index[i]), + "slice_intersection"); + } + + // Emit: + // if (slice_intersection) -> return data from 'update'. + // else -> return data from 'index'. + llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry( + llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), + ir_builder_), + "ret_value_addr", ir_builder_); + llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( + slice_intersection, "slice_intersection", ir_builder_); + + // Handle true BB. 
+ SetToFirstInsertPoint(if_data.true_block, ir_builder_); + // Compute update index for intersection case. + llvm_ir::IrArray::Index update_index(rank); + for (int64 i = 0; i < rank; ++i) { + llvm::Value* update_dim_size = llvm::ConstantInt::get( + index[i]->getType(), update_hlo->shape().dimensions(i)); + // NOTE: Subtraction will be positive due to bounds checking above. + update_index[i] = ir_builder_->CreateURem( + ir_builder_->CreateSub(index[i], slice_start_index[i]), + update_dim_size); + } + TF_ASSIGN_OR_RETURN(llvm::Value * true_value, + operand_to_generator.at(update_hlo)(update_index)); + ir_builder_->CreateStore(true_value, ret_value_addr); + + // Handle false BB. + SetToFirstInsertPoint(if_data.false_block, ir_builder_); + TF_ASSIGN_OR_RETURN(llvm::Value * false_value, + operand_to_generator.at(input_hlo)(index)); + ir_builder_->CreateStore(false_value, ret_value_addr); + + SetToFirstInsertPoint(if_data.after_block, ir_builder_); + return ir_builder_->CreateLoad(ret_value_addr); + }; + case HloOpcode::kReshape: + CHECK_EQ(ShapeUtil::ElementsIn(hlo->shape()), + ShapeUtil::ElementsIn(hlo->operand(0)->shape())); + return [=, &operand_to_generator](const IrArray::Index& index) { + const HloInstruction* operand = hlo->operand(0); + return operand_to_generator.at(operand)(index.SourceIndexOfReshape( + hlo->shape(), operand->shape(), ir_builder_)); + }; + case HloOpcode::kTranspose: + return [=, &operand_to_generator](const IrArray::Index& target_index) { + return operand_to_generator.at(hlo->operand(0))( + target_index.SourceIndexOfTranspose( + hlo->shape(), hlo->operand(0)->shape(), hlo->dimensions(), + ir_builder_)); + }; + case HloOpcode::kRng: + return MakeRngElementGenerator(hlo, operand_to_generator); + default: + return [=, &operand_to_generator](const IrArray::Index& index) { + return Unimplemented("%s", HloOpcodeString(hlo->opcode()).c_str()); + }; + } +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h new file mode 100644 index 0000000000..2576d3823e --- /dev/null +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h @@ -0,0 +1,118 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_ELEMENTAL_IR_EMITTER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_ELEMENTAL_IR_EMITTER_H_
+
+#include <unordered_map>
+
+#include "external/llvm/include/llvm/IR/IRBuilder.h"
+#include "external/llvm/include/llvm/IR/Module.h"
+#include "external/llvm/include/llvm/IR/Value.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
+#include "tensorflow/compiler/xla/statusor.h"
+
+namespace xla {
+
+class ElementalIrEmitter {
+ public:
+  using HloToElementGeneratorMap =
+      std::unordered_map<const HloInstruction*, llvm_ir::ElementGenerator>;
+
+  ElementalIrEmitter(const HloModuleConfig& hlo_module_config,
+                     llvm::Module* module, llvm::IRBuilder<>* ir_builder)
+      : ir_builder_(ir_builder),
+        module_(module),
+        hlo_module_config_(hlo_module_config) {}
+
+  virtual ~ElementalIrEmitter() {}
+
+  virtual StatusOr<llvm::Value*> EmitUnaryOp(const HloInstruction* op,
+                                             llvm::Value* operand_value) const;
+
+  virtual StatusOr<llvm::Value*> EmitBinaryOp(const HloInstruction* op,
+                                              llvm::Value* lhs_value,
+                                              llvm::Value* rhs_value) const;
+
+  // Returns a function to generate an element of the output of `hlo`, given a
+  // map of functions to generate elements of its operands.
+  virtual llvm_ir::ElementGenerator MakeElementGenerator(
+      const HloInstruction* hlo,
+      const HloToElementGeneratorMap& operand_to_generator) const;
+
+  llvm::IRBuilder<>* ir_builder() const { return ir_builder_; }
+
+ protected:
+  virtual StatusOr<llvm::Value*> EmitIntegerUnaryOp(
+      const HloInstruction* op, llvm::Value* operand_value) const;
+
+  virtual StatusOr<llvm::Value*> EmitFloatUnaryOp(
+      const HloInstruction* op, llvm::Value* operand_value) const;
+
+  virtual StatusOr<llvm::Value*> EmitIntegerBinaryOp(const HloInstruction* op,
+                                                     llvm::Value* lhs_value,
+                                                     llvm::Value* rhs_value,
+                                                     bool is_signed) const;
+
+  virtual StatusOr<llvm::Value*> EmitFloatBinaryOp(
+      const HloInstruction* op, llvm::Value* lhs_value,
+      llvm::Value* rhs_value) const;
+
+  virtual llvm::Value* EmitFloatMax(llvm::Value* lhs_value,
+                                    llvm::Value* rhs_value) const;
+
+  virtual llvm::Value* EmitFloatMin(llvm::Value* lhs_value,
+                                    llvm::Value* rhs_value) const;
+
+  virtual StatusOr<llvm::Value*> EmitErfInv(PrimitiveType prim_type,
+                                            llvm::Value* value) const;
+
+  virtual StatusOr<llvm::Value*> EmitErfcInv(PrimitiveType prim_type,
+                                             llvm::Value* value) const;
+
+  // A helper method for MakeElementGenerator. Given an elementwise op `hlo` and
+  // the target array index, computes the source array index of its
+  // `operand_no`-th operand.
+  //
+  // Precondition: `hlo` is an elementwise op.
+  llvm_ir::IrArray::Index ElementwiseSourceIndex(
+      const llvm_ir::IrArray::Index& target_index, const HloInstruction& hlo,
+      int64 operand_no) const;
+
+  // Identifier of the thread unique among all threads on the device.
+  virtual llvm::Value* EmitThreadId() const {
+    return ir_builder_->getIntN(128, 0);
+  }
+
+  llvm::IRBuilder<>* const ir_builder_;
+
+  llvm::Module* module_;
+
+  // The HloModuleConfig which gathers all settings and values which affect the
+  // compiled executable outside of the HLO code itself.
+  const HloModuleConfig& hlo_module_config_;
+
+ private:
+  // Returns an ElementGenerator for a RNG HloInstruction.
+ llvm_ir::ElementGenerator MakeRngElementGenerator( + const HloInstruction* hlo, + const HloToElementGeneratorMap& operand_to_generator) const; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_ELEMENTAL_IR_EMITTER_H_ diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc new file mode 100644 index 0000000000..5b1a5a16d1 --- /dev/null +++ b/tensorflow/compiler/xla/service/executable.cc @@ -0,0 +1,82 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/executable.h" + +#include "tensorflow/compiler/xla/legacy_flags/service_flags.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/env.h" + +namespace xla { + +StatusOr> +Executable::ExecuteOnStreams( + tensorflow::gtl::ArraySlice run_options, + tensorflow::gtl::ArraySlice< + tensorflow::gtl::ArraySlice> + arguments) { + TF_RET_CHECK(run_options.size() == arguments.size()); + + if (run_options.size() == 1) { + TF_ASSIGN_OR_RETURN(auto result, + ExecuteOnStream(&run_options[0], arguments[0], + /*hlo_execution_profile=*/nullptr)); + return std::vector({result}); + } + + std::vector return_values( + run_options.size()); + for (int64 i = 0; i < run_options.size(); ++i) { + // We cannot BlockHostUntilDone() on the already-launched executions in case + // of error, since if the executions communicate, the initially launched + // executions may never complete if not all executions are running. + TF_ASSIGN_OR_RETURN(return_values[i], + ExecuteAsyncOnStream(&run_options[i], arguments[i])); + } + for (const auto& options : run_options) { + TF_RET_CHECK(options.stream() != nullptr); + options.stream()->BlockHostUntilDone(); + } + return return_values; +} + +Status Executable::DumpSessionModule() { + TF_RET_CHECK(dumping()); + legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags(); + const string& directory_path = flags->xla_dump_executions_to; + VersionedComputationHandle versioned_handle = entry_computation_handle(); + // This filename does not include the version number because the computation + // is only ever executed at one version. 
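+  // For example (values hypothetical), the resulting name looks like
+  // "computation_42__my_computation__execution_7".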
+ string filename = tensorflow::strings::Printf( + "computation_%lld__%s__execution_%lld", versioned_handle.handle.handle(), + session_module_->entry().name().c_str(), ++execution_count_); + return Executable::DumpToDirectory(directory_path, filename, + *session_module_); +} + +/* static */ Status Executable::DumpToDirectory( + const string& directory_path, const string& filename, + const SessionModule& session_module) { + tensorflow::Env* env = tensorflow::Env::Default(); + if (!env->IsDirectory(directory_path).ok()) { + TF_RETURN_IF_ERROR(env->CreateDir(directory_path)); + } + string file_path = tensorflow::io::JoinPath(directory_path, filename); + return tensorflow::WriteBinaryProto(env, file_path, session_module); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h new file mode 100644 index 0000000000..373ab79ab2 --- /dev/null +++ b/tensorflow/compiler/xla/service/executable.h @@ -0,0 +1,168 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_EXECUTABLE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_EXECUTABLE_H_ + +#include +#include + +#include "tensorflow/compiler/xla/executable_run_options.h" +#include "tensorflow/compiler/xla/service/computation_layout.h" +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/hlo_execution_profile.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/session.pb.h" +#include "tensorflow/compiler/xla/service/shaped_buffer.h" +#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/thread_annotations.h" + +namespace xla { + +// A given platform's compiler will produce an Executable -- this is a uniform +// interface that is used for launching compiled programs across platforms. +// +// TODO(leary) will need to extend this to support multiple streams/devices as +// we begin to compile single programs to run on multiple devices. +class Executable { + public: + explicit Executable(std::unique_ptr hlo_module, + std::unique_ptr module_config) + : hlo_module_(std::move(hlo_module)), + module_config_(std::move(module_config)) {} + virtual ~Executable() {} + + // Enqueues the compilation result on the provided stream, passing the given + // arguments. This call is blocking and returns after the execution is done. + // + // If the hlo_execution_profile is provided as non-nullptr, profiling will be + // enabled. 
+ // + // Returns the device memory region that a successful execution would + // populate. + virtual StatusOr ExecuteOnStream( + const ExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice + arguments, + HloExecutionProfile* hlo_execution_profile) = 0; + + // Overload of ExecuteOnStream which returns and takes arguments as + // ShapedBuffers. Used for LocalService execution. + virtual StatusOr> ExecuteOnStream( + const ExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice arguments, + HloExecutionProfile* hlo_execution_profile) = 0; + + // Overload of which writes the result into a pre-allocated buffer + // (result_buffer). + virtual Status ExecuteOnStream( + const ExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice arguments, + ShapedBuffer* result_buffer, + HloExecutionProfile* hlo_execution_profile) = 0; + + // Same as ExecuteOnStream(), but this call is non-blocking and returns as + // soon as all of the operations are enqueued for launch on the stream. + virtual StatusOr ExecuteAsyncOnStream( + const ExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice + arguments) = 0; + + // Same as ExecuteOnStream(), but runs this executable on multiple + // streams. arguments[i] contains the arguments to the execution on + // run_options[i]->stream() and the returned value is at index i of the + // returned vector. + virtual StatusOr> + ExecuteOnStreams( + tensorflow::gtl::ArraySlice run_options, + tensorflow::gtl::ArraySlice< + tensorflow::gtl::ArraySlice> + arguments); + + // Returns the ExecutionProfile from executing on the device. This includes + // the number of cycles taken for the computation or the compilation time. + ExecutionProfile execution_profile() const { + tensorflow::mutex_lock lock(mutex_); + return execution_profile_; + } + + // Returns whether this executable was compiled with HLO profilings support + // enabled. If not, the caller should not expect an hlo_execution_profile + // passed to ExecuteOnStream above to be populated during execution. + bool hlo_profiling_enabled() const { + return module_config_->hlo_profiling_enabled(); + } + + const HloModule& module() const { return *hlo_module_; } + + const HloModuleConfig& module_config() const { return *module_config_; } + + // Returns the versioned computation handle of the computation computed by + // this executable. + const VersionedComputationHandle& entry_computation_handle() const { + return hlo_module_->entry_computation_handle(); + } + + // The shape (including layout) that results from this execution. This is the + // shape of the DeviceMemoryBase result value in ExecuteOnStream above. + const Shape& result_shape() const { + return module_config_->entry_computation_layout().result_shape(); + } + + // Dumping helpers. + void set_session_module(std::unique_ptr session_module) { + session_module_ = std::move(session_module); + } + bool dumping() const { return session_module_ != nullptr; } + SessionModule* session_module() const { return session_module_.get(); } + Status DumpSessionModule(); + + // Dump session_module to directory_path/filename. + static Status DumpToDirectory(const string& directory_path, + const string& filename, + const SessionModule& session_module); + + protected: + mutable tensorflow::mutex mutex_; + + // Execution profile data on the device. + ExecutionProfile execution_profile_ GUARDED_BY(mutex_); + + // HloModule this was compiled from. 
BufferAssignment keeps pointers to + // HloInstructions owned by the HloModule so we need to keep the HloModule + // around. + std::unique_ptr hlo_module_; + + // The configuration used to build this executable (parameter layouts, result + // layout, profiling enabled, etc). + std::unique_ptr module_config_; + + // SessionModule this was compiled from. Null if not dumping executions. + std::unique_ptr session_module_; + + // Execution count, used to generate a unique filename for each dumped + // execution. + int64 execution_count_ = 0; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_EXECUTABLE_H_ diff --git a/tensorflow/compiler/xla/service/execution_tracker.cc b/tensorflow/compiler/xla/service/execution_tracker.cc new file mode 100644 index 0000000000..cf1870580c --- /dev/null +++ b/tensorflow/compiler/xla/service/execution_tracker.cc @@ -0,0 +1,95 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/execution_tracker.h" + +#include + +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { + +AsyncExecution::AsyncExecution( + Backend* backend, + std::vector> streams, + const ExecutionProfile& profile, GlobalDataHandle result) + : backend_(CHECK_NOTNULL(backend)), + streams_(std::move(streams)), + profile_(profile), + result_(result) { + for (const auto& stream : streams_) { + CHECK(stream != nullptr); + } +} + +AsyncExecution::~AsyncExecution() { + for (auto& stream : streams_) { + backend_->ReleaseStream(std::move(stream)); + } +} + +tensorflow::Status AsyncExecution::BlockUntilDone() const { + for (auto& stream : streams_) { + if (!stream->BlockHostUntilDone()) { + return InternalError("failed to block until done"); + } + } + return tensorflow::Status::OK(); +} + +ExecutionTracker::ExecutionTracker() : next_handle_(1) {} + +ExecutionHandle ExecutionTracker::Register( + Backend* backend, + std::vector> streams, + const ExecutionProfile& profile, GlobalDataHandle result) { + tensorflow::mutex_lock lock(execution_mutex_); + int64 handle = next_handle_++; + auto inserted = handle_to_execution_.emplace( + handle, + MakeUnique(backend, std::move(streams), profile, result)); + CHECK(inserted.second); + + ExecutionHandle execution_handle; + execution_handle.set_handle(handle); + return execution_handle; +} + +tensorflow::Status ExecutionTracker::Unregister(const ExecutionHandle& handle) { + tensorflow::mutex_lock lock(execution_mutex_); + auto it = handle_to_execution_.find(handle.handle()); + if (it == handle_to_execution_.end()) { + return NotFound("no execution record for execution handle: %lld", + handle.handle()); + } + handle_to_execution_.erase(handle.handle()); + return tensorflow::Status::OK(); +} + +StatusOr ExecutionTracker::Resolve( + 
const ExecutionHandle& handle) { + tensorflow::mutex_lock lock(execution_mutex_); + auto it = handle_to_execution_.find(handle.handle()); + if (it == handle_to_execution_.end()) { + return NotFound("no execution record for execution handle: %lld", + handle.handle()); + } + return it->second.get(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/execution_tracker.h b/tensorflow/compiler/xla/service/execution_tracker.h new file mode 100644 index 0000000000..99a5bb5ad9 --- /dev/null +++ b/tensorflow/compiler/xla/service/execution_tracker.h @@ -0,0 +1,105 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_EXECUTION_TRACKER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_EXECUTION_TRACKER_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/executable_run_options.h" +#include "tensorflow/compiler/xla/service/backend.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// Represents an asynchronously launched execution. Owns the stream (from the +// passed run_options->stream()) on which the execution is launched and releases +// the stream when destructed. +class AsyncExecution { + public: + AsyncExecution( + Backend* backend, + std::vector> streams, + const ExecutionProfile& profile, GlobalDataHandle result); + + ~AsyncExecution(); + tensorflow::Status BlockUntilDone() const; + + const GlobalDataHandle& result() const { return result_; } + + const ExecutionProfile& profile() const { return profile_; } + + private: + // Backend to execute the computation on. + Backend* backend_; + + // Stream on which the execution is launched. + std::vector> streams_; + + // Profile object of the execution to be returned to the user. + ExecutionProfile profile_; + + // Data handle to the result of the execution. Data represented by this handle + // is valid only after BlockUntilDone() is called. + GlobalDataHandle result_; +}; + +// Tracks asynchronously launched executions for the XLA service. +class ExecutionTracker { + public: + ExecutionTracker(); + + // Registers an execution with its backend, streams, and data handle to the + // execution result. Returns a handle for the registered execution. + ExecutionHandle Register( + Backend* backend, + std::vector> stream, + const ExecutionProfile& profile, GlobalDataHandle data); + + // Unregisters the execution for the given handle. 
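+  // Once unregistered, Resolve() on the same handle returns a NotFound error.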
+ tensorflow::Status Unregister(const ExecutionHandle& handle); + + // Resolves the given ExecutionHandle to an AsyncExecution. Returns an + // error status if the given handle is not found, which means that the + // execution is not yet registered or already unregistered. + StatusOr Resolve(const ExecutionHandle& handle); + + private: + // The next handle to assign to an execution. + int64 next_handle_ GUARDED_BY(execution_mutex_); + + // Mapping from ExecutionHandle handle to the corresponding registered + // AsyncExecution object. + std::map> handle_to_execution_ + GUARDED_BY(execution_mutex_); + + tensorflow::mutex execution_mutex_; // Guards the execution mapping. + + TF_DISALLOW_COPY_AND_ASSIGN(ExecutionTracker); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_EXECUTION_TRACKER_H_ diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc new file mode 100644 index 0000000000..086306696d --- /dev/null +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -0,0 +1,183 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/generic_transfer_manager.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace se = ::perftools::gputools; + +namespace xla { + +GenericTransferManager::GenericTransferManager(se::Platform::Id platform_id) + : platform_id_(platform_id) { + // We currently only support kHostPlatformId for CPU and kCudaPlatformId for + // GPU. Before supporting other platforms, we need to test this transfer + // manager on them. 
+ CHECK(platform_id_ == se::host::kHostPlatformId || + platform_id_ == se::cuda::kCudaPlatformId); +} + +se::Platform::Id GenericTransferManager::PlatformId() const { + if (platform_id_ == se::cuda::kCudaPlatformId || + platform_id_ == se::host::kHostPlatformId) { + return platform_id_; + } + CHECK(false) << "GenericTransferManager::platform_id_ is invalid"; +} + +Status GenericTransferManager::TransferLiteralFromDevice( + se::StreamExecutor* executor, const se::DeviceMemoryBase& source, + const Shape& device_shape, const Shape& literal_shape, Literal* literal) { + VLOG(2) << "transferring literal shape from device: " + << ShapeUtil::HumanString(literal_shape) + << "; device location: " << source.opaque(); + TF_RET_CHECK(ShapeUtil::Compatible(device_shape, literal_shape)); + + // Tuples are a special case and contain one or more shapes inside of them to + // an arbitrary nesting depth. + if (device_shape.element_type() == TUPLE) { + *literal->mutable_shape() = literal_shape; + TF_ASSIGN_OR_RETURN( + std::vector element_buffers, + ShallowCopyTupleFromDevice(executor, source, device_shape)); + TF_RET_CHECK(element_buffers.size() == + ShapeUtil::TupleElementCount(device_shape)); + for (int64 i = 0; i < element_buffers.size(); ++i) { + const Shape& element_device_shape = device_shape.tuple_shapes(i); + const Shape& element_literal_shape = literal_shape.tuple_shapes(i); + Literal* element_literal = literal->add_tuple_literals(); + // Recursively call TransferFromDevice to copy over the data in the + // element array. + TF_RETURN_IF_ERROR(TransferLiteralFromDevice( + executor, element_buffers[i], /*device_shape=*/element_device_shape, + /*literal_shape=*/element_literal_shape, element_literal)); + } + return Status::OK(); + } + + *literal->mutable_shape() = device_shape; + LiteralUtil::Reserve(ShapeUtil::ElementsIn(device_shape), literal); + TF_RETURN_IF_ERROR(TransferBufferFromDevice( + executor, source, /*size=*/ShapeUtil::ByteSizeOf(device_shape), + /*destination=*/LiteralUtil::MutableInternalData(literal))); + if (!ShapeUtil::Equal(literal_shape, device_shape)) { + literal->Swap( + LiteralUtil::Relayout(*literal, literal_shape.layout()).get()); + } + TF_RET_CHECK(ShapeUtil::Equal(literal_shape, literal->shape())); + return Status::OK(); +} + +StatusOr> +GenericTransferManager::ShallowCopyTupleFromDevice( + se::StreamExecutor* executor, const se::DeviceMemoryBase& source, + const Shape& shape) { + TF_RET_CHECK(ShapeUtil::IsTuple(shape)); + + // For devices which use the GenericTransferManager, a tuple is stored as an + // array of pointers to buffers. Copy the contents of the tuple buffer into + // a vector of void* pointers. + std::vector element_pointers(ShapeUtil::TupleElementCount(shape), + nullptr); + int64 tuple_size = ShapeUtil::ByteSizeOf(shape); + auto copy_status = executor->SynchronousMemcpyD2H(source, tuple_size, + element_pointers.data()); + if (!copy_status.ok()) { + return AddStatus( + Status(static_cast(copy_status.code()), + copy_status.error_message()), + "failed transfer of tuple buffer " + ShapeUtil::HumanString(shape)); + } + + // Create a DeviceMemoryBase from each void* pointer. 
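+  // Illustration (shapes hypothetical): for a tuple (f32[4], f32[8]) the source
+  // buffer holds two consecutive void* slots, and the loop below wraps them as
+  // DeviceMemoryBase values of 16 and 32 bytes respectively.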
+ std::vector destination; + for (int i = 0; i < element_pointers.size(); ++i) { + if (element_pointers[i] == nullptr && + !ShapeUtil::HasZeroElements(shape.tuple_shapes(i))) { + return FailedPrecondition("tuple contains nullptr at element %d", i); + } + int64 buffer_size = ShapeUtil::ByteSizeOf(shape.tuple_shapes(i)); + destination.emplace_back(element_pointers[i], buffer_size); + } + return std::move(destination); +} + +Status GenericTransferManager::TransferLiteralToDevice( + se::StreamExecutor* executor, const Literal& literal, + se::DeviceMemoryBase* destination) { + const Shape& shape = literal.shape(); + VLOG(2) << "transferring literal shape to device: " + << ShapeUtil::HumanString(shape) + << "; device location: " << destination->opaque(); + + if (ShapeUtil::IsTuple(literal.shape())) { + std::vector tuple_elements_on_device; + for (const Literal& tuple_element : literal.tuple_literals()) { + se::DeviceMemoryBase allocation = executor->AllocateArray( + GetByteSizeRequirement(tuple_element.shape())); + TF_RETURN_IF_ERROR( + TransferLiteralToDevice(executor, tuple_element, &allocation)); + tuple_elements_on_device.push_back(allocation.opaque()); + } + return TransferBufferToDevice( + executor, tuple_elements_on_device.size() * sizeof(void*), + tuple_elements_on_device.data(), destination); + } + + return TransferBufferToDevice( + executor, /*size=*/GetByteSizeRequirement(shape), + /*source=*/LiteralUtil::InternalData(literal), destination); +} + +Status GenericTransferManager::TransferLiteralToInfeed( + se::StreamExecutor* executor, const Literal& literal) { + return Unimplemented("Infeed is not supported on GPU (b/30467474)"); +} + +Status GenericTransferManager::ResetDevice(se::StreamExecutor* executor) { + return Unimplemented( + "Device reset is not yet supported on CPU and GPU (b/30481585)"); +} + +int64 GenericTransferManager::GetByteSizeRequirement(const Shape& shape) { + return ShapeUtil::ByteSizeOf(shape); +} + +} // namespace xla + +static xla::TransferManager* CreateGenericTransferManager() { + return new xla::GenericTransferManager(se::cuda::kCudaPlatformId); +} + +static bool InitModule() { + xla::TransferManager::RegisterTransferManager(se::cuda::kCudaPlatformId, + CreateGenericTransferManager); + return true; +} +static bool module_initialized = InitModule(); diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h new file mode 100644 index 0000000000..cfa02bf22f --- /dev/null +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h @@ -0,0 +1,77 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GENERIC_TRANSFER_MANAGER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GENERIC_TRANSFER_MANAGER_H_ + +#include + +#include "tensorflow/compiler/xla/service/transfer_manager.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// A generic implementation of the XLA TransferManager interface +// that is the base class for both CPU and GPU. For GPU, it transfers +// data between host and device (GPU). For CPU, since the "device" +// here is the host itself, there's not much for this transfer manager +// to do except memcpy the result. There is a CpuTransferManager that +// inherits from GenericTransferManager and handles CPU-specific +// infeed. +class GenericTransferManager : public TransferManager { + public: + explicit GenericTransferManager( + perftools::gputools::Platform::Id platform_id); + ~GenericTransferManager() override {} + + perftools::gputools::Platform::Id PlatformId() const override; + + Status TransferLiteralFromDevice( + perftools::gputools::StreamExecutor* executor, + const perftools::gputools::DeviceMemoryBase& source, + const Shape& device_shape, const Shape& literal_shape, + Literal* literal) override; + + Status TransferLiteralToDevice( + perftools::gputools::StreamExecutor* executor, const Literal& literal, + perftools::gputools::DeviceMemoryBase* destination) override; + + Status TransferLiteralToInfeed(perftools::gputools::StreamExecutor* executor, + const Literal& literal) override; + + Status ResetDevice(perftools::gputools::StreamExecutor* executor) override; + + StatusOr> + ShallowCopyTupleFromDevice( + perftools::gputools::StreamExecutor* executor, + const perftools::gputools::DeviceMemoryBase& source, + const Shape& shape) override; + + int64 GetByteSizeRequirement(const Shape& shape) override; + + private: + // The platform this transfer manager targets. + perftools::gputools::Platform::Id platform_id_; + + TF_DISALLOW_COPY_AND_ASSIGN(GenericTransferManager); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GENERIC_TRANSFER_MANAGER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD new file mode 100644 index 0000000000..9aeebe42f8 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -0,0 +1,533 @@ +# Description: +# GPU-specific components in XLA service implementation. + +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = [":friends"]) + +package_group( + name = "friends", + includes = [ + "//tensorflow/compiler/xla:friends", + ], +) + +# Filegroup used to collect source files for dependency checking. +filegroup( + name = "c_srcs", + data = glob([ + "**/*.cc", + "**/*.h", + ]), +) + +cc_library( + name = "partition_assignment", + srcs = [ + "partition_assignment.cc", + ], + hdrs = [ + "partition_assignment.h", + ], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + ], +) + +# TODO(b/29140563) This target is flaky, disabled until flakiness is +# root-caused. Failed on 2016-06-08. 
+#cc_test( +# name = "partition_assignment_test", +# srcs = [ +# "partition_assignment_test.cc", +# ], +# tags = [ +# "requires-gpu-sm35", +# ], +# deps = [ +# ":partition_assignment", +# "//tensorflow/core:stream_executor_no_cuda", +# "//tensorflow/compiler/xla:shape_util", +# "//tensorflow/compiler/xla:xla_data_proto", +# "//tensorflow/compiler/xla/service:gpu_plugin", +# "//tensorflow/compiler/xla/service:hlo", +# "//tensorflow/compiler/xla/tests:hlo_test_base", +# "//tensorflow/core:test_main", +# ], +#) + +cc_library( + name = "stream_assignment", + srcs = ["stream_assignment.cc"], + hdrs = ["stream_assignment.h"], + deps = [ + ":ir_emission_utils", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/legacy_flags:stream_assignment_flags", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "stream_assignment_test", + srcs = [ + "stream_assignment_test.cc", + ], + deps = [ + ":stream_assignment", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:lib", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "hlo_to_ir_bindings", + srcs = ["hlo_to_ir_bindings.cc"], + hdrs = ["hlo_to_ir_bindings.h"], + deps = [ + ":ir_emission_utils", + ":temp_buffer_offsets", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:buffer_assignment", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service/llvm_ir:alias_analysis", + "//tensorflow/compiler/xla/service/llvm_ir:ir_array", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/compiler/xla/service/llvm_ir:ops", + "//tensorflow/core:lib", + "@llvm//:core", + ], +) + +cc_library( + name = "ir_emitter", + srcs = [ + "ir_emitter.cc", + "ir_emitter_nested.cc", + "ir_emitter_unnested.cc", + ], + hdrs = [ + "ir_emitter.h", + "ir_emitter_context.h", + ], + deps = [ + ":elemental_ir_emitter", + ":gpu_executable", + ":hlo_to_ir_bindings", + ":ir_emission_utils", + ":parallel_loop_emitter", + ":partition_assignment", + ":temp_buffer_offsets", + ":while_transformer", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:window_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:buffer_assignment", + "//tensorflow/compiler/xla/service:elemental_ir_emitter", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:name_uniquer", + "//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter", + "//tensorflow/compiler/xla/service/llvm_ir:ir_array", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", + "//tensorflow/compiler/xla/service/llvm_ir:ops", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + "@llvm//:core", + "@llvm//:support", + ], +) + +cc_library( + name = "parallel_loop_emitter", + srcs = ["parallel_loop_emitter.cc"], + hdrs = ["parallel_loop_emitter.h"], + deps = [ + ":partition_assignment", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service/llvm_ir:ir_array", + 
"//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", + "//tensorflow/core:lib", + "@llvm//:core", + ], +) + +cc_library( + name = "elemental_ir_emitter", + srcs = ["elemental_ir_emitter.cc"], + hdrs = ["elemental_ir_emitter.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:window_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:elemental_ir_emitter", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/service/llvm_ir:ir_array", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", + "//tensorflow/core:lib", + "@llvm//:core", + "@llvm//:support", + ], +) + +cc_library( + name = "temp_buffer_offsets", + srcs = ["temp_buffer_offsets.cc"], + hdrs = ["temp_buffer_offsets.h"], + deps = [ + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:buffer_assignment", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "buffer_allocations", + srcs = ["buffer_allocations.cc"], + hdrs = ["buffer_allocations.h"], + deps = [ + ":temp_buffer_offsets", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:buffer_assignment", + "//tensorflow/compiler/xla/service:device_memory_allocator", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + ], +) + +cc_library( + name = "gpu_executable", + srcs = [ + "convolution_thunk.cc", + "copy_thunk.cc", + "for_thunk.cc", + "gemm_thunk.cc", + "gpu_executable.cc", + "kernel_thunk.cc", + "sequential_thunk.cc", + "thunk_schedule.cc", + "tuple_thunk.cc", + "while_thunk.cc", + ], + hdrs = [ + "convolution_thunk.h", + "copy_thunk.h", + "for_thunk.h", + "gemm_thunk.h", + "gpu_executable.h", + "kernel_thunk.h", + "sequential_thunk.h", + "thunk.h", + "thunk_schedule.h", + "tuple_thunk.h", + "while_thunk.h", + ], + deps = [ + ":buffer_allocations", + ":partition_assignment", + ":stream_assignment", + ":temp_buffer_offsets", + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:shape_tree", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/legacy_flags:convolution_thunk_flags", + "//tensorflow/compiler/xla/service:buffer_assignment", + "//tensorflow/compiler/xla/service:device_memory_allocator", + "//tensorflow/compiler/xla/service:executable", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_execution_profile", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/service:logical_buffer", + "//tensorflow/compiler/xla/service:shaped_buffer", + "//tensorflow/compiler/xla/service:transfer_manager", + "//tensorflow/compiler/xla/service:tuple_points_to_analysis", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + 
"//tensorflow/core/platform/default/build_config:stream_executor_cuda", + ], +) + +cc_library( + name = "ir_emission_utils", + srcs = ["ir_emission_utils.cc"], + hdrs = ["ir_emission_utils.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:window_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/core:lib", + "@llvm//:core", + ], +) + +cc_library( + name = "convolution_folding", + srcs = ["convolution_folding.cc"], + hdrs = ["convolution_folding.h"], + deps = [ + ":ir_emission_utils", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:window_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "convolution_folding_test", + srcs = ["convolution_folding_test.cc"], + deps = [ + ":convolution_folding", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:shape_inference", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "instruction_fusion", + srcs = ["instruction_fusion.cc"], + hdrs = ["instruction_fusion.h"], + deps = [ + ":ir_emission_utils", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:instruction_fusion", + ], +) + +cc_test( + name = "instruction_fusion_test", + srcs = ["instruction_fusion_test.cc"], + deps = [ + ":instruction_fusion", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "copy_insertion", + srcs = ["copy_insertion.cc"], + hdrs = ["copy_insertion.h"], + deps = [ + ":ir_emission_utils", + "//tensorflow/compiler/xla/service:copy_insertion", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:logical_buffer", + "//tensorflow/compiler/xla/service:tuple_points_to_analysis", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "pad_insertion", + srcs = ["pad_insertion.cc"], + hdrs = ["pad_insertion.h"], + deps = [ + ":ir_emission_utils", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:window_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/compiler/xla/service:shape_inference", + ], +) + +cc_library( + name = "gpu_compiler", + srcs = ["gpu_compiler.cc"], + hdrs = ["gpu_compiler.h"], + deps = [ + ":convolution_folding", + ":copy_insertion", + ":gpu_executable", + ":hlo_schedule", + ":instruction_fusion", + ":ir_emission_utils", + ":ir_emitter", + ":layout_assignment", + ":pad_insertion", + ":partition_assignment", + ":stream_assignment", + ":temp_buffer_offsets", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/legacy_flags:gpu_compiler_flags", + "//tensorflow/compiler/xla/service:algebraic_simplifier", + "//tensorflow/compiler/xla/service:buffer_assignment", + "//tensorflow/compiler/xla/service:buffer_liveness", + "//tensorflow/compiler/xla/service:compiler", + 
"//tensorflow/compiler/xla/service:executable", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_cse", + "//tensorflow/compiler/xla/service:hlo_dce", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/compiler/xla/service:hlo_pass_pipeline", + "//tensorflow/compiler/xla/service:hlo_subcomputation_unification", + "//tensorflow/compiler/xla/service:reshape_mover", + "//tensorflow/compiler/xla/service:transpose_folding", + "//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/core:cuda_libdevice_path", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + "@llvm//:core", + "@llvm//:support", + ], + alwayslink = True, # Contains compiler registration +) + +cc_library( + name = "layout_assignment", + srcs = ["layout_assignment.cc"], + hdrs = ["layout_assignment.h"], + deps = [ + ":ir_emission_utils", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:computation_layout", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:layout_assignment", + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "layout_assignment_test", + srcs = ["layout_assignment_test.cc"], + deps = [ + ":layout_assignment", + "//tensorflow/compiler/xla:shape_layout", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:computation_layout", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "hlo_schedule", + srcs = ["hlo_schedule.cc"], + hdrs = ["hlo_schedule.h"], + deps = [ + ":stream_assignment", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:buffer_liveness", + "//tensorflow/compiler/xla/service:hlo", + ], +) + +cc_test( + name = "hlo_schedule_test", + srcs = [ + "hlo_schedule_test.cc", + ], + deps = [ + ":hlo_schedule", + ":stream_assignment", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "while_transformer", + srcs = ["while_transformer.cc"], + hdrs = ["while_transformer.h"], + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "while_transformer_test", + srcs = ["while_transformer_test.cc"], + deps = [ + ":while_transformer", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +# ----------------------------------------------------------------------------- + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + 
visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc new file mode 100644 index 0000000000..a9975de3f1 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc @@ -0,0 +1,139 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" + +#include + +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace se = ::perftools::gputools; + +namespace xla { +namespace gpu { + +void BufferAllocations::Builder::RegisterBuffer(BufferAllocation::Index index, + se::DeviceMemoryBase address) { + InsertOrDie(®istered_buffers_, index, address); +} + +StatusOr> BufferAllocations::Builder::Build( + const BufferAssignment& buffer_assignment, + const TempBufferOffsets& temp_buffer_offsets, int device_ordinal, + DeviceMemoryAllocator* memory_allocator) { + se::DeviceMemoryBase temp_buffer_base; + if (temp_buffer_offsets.TotalSizeInBytes() > 0) { + TF_ASSIGN_OR_RETURN( + temp_buffer_base, + memory_allocator->Allocate(device_ordinal, + temp_buffer_offsets.TotalSizeInBytes())); + if (temp_buffer_base == nullptr) { + return ResourceExhausted( + "Out of memory when allocating %s bytes for temporary buffers.", + tensorflow::strings::HumanReadableNumBytes( + temp_buffer_offsets.TotalSizeInBytes()) + .c_str()); + } + } + auto buffer_allocations = WrapUnique(new BufferAllocations( + buffer_assignment.Allocations().size(), temp_buffer_base, device_ordinal, + memory_allocator)); + + int64 num_buffers = buffer_assignment.Allocations().size(); + for (BufferAllocation::Index i = 0; i < num_buffers; ++i) { + // If buffer #i's address is already registered (e.g. external arguments or + // result buffers), use that registered buffer. + if (registered_buffers_.count(i)) { + buffer_allocations->SetBuffer(i, FindOrDie(registered_buffers_, i)); + continue; + } + + const BufferAllocation& allocation = buffer_assignment.GetAllocation(i); + if (allocation.maybe_live_out()) { + auto buffer_size = allocation.size(); + se::DeviceMemoryBase buffer_address; + if (buffer_size > 0) { + // If the buffer escapes, we need to allocate it separately instead of + // merging it into the memory block for temporary buffers. 
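+        // For example (sizes hypothetical): a 1 MB maybe_live_out result
+        // allocation gets its own device allocation here, while purely
+        // temporary allocations are carved out of the shared temp block at
+        // their TempBufferOffsets offsets in the branch below.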
+ TF_ASSIGN_OR_RETURN(buffer_address, memory_allocator->Allocate( + device_ordinal, buffer_size)); + if (buffer_address == nullptr) { + return ResourceExhausted( + "Out of memory when allocating %s for buffer %lld.", + tensorflow::strings::HumanReadableNumBytes(buffer_size).c_str(), + i); + } + } + buffer_allocations->SetBuffer(i, buffer_address); + } else if (allocation.IsPreallocatedTempBuffer()) { + se::DeviceMemoryBase temp_buffer_address( + /*opaque=*/static_cast( + buffer_allocations->GetTempBufferBase().opaque()) + + temp_buffer_offsets.GetOffset(i), + /*size=*/allocation.size()); + buffer_allocations->SetBuffer(i, temp_buffer_address); + } + } + + return std::move(buffer_allocations); +} + +tensorflow::Status BufferAllocations::TearDown( + const std::set& live_addresses, + const BufferAssignment& buffer_assignment) { + // Deallocate temporary buffers. + for (auto i = 0; i < buffer_assignment.Allocations().size(); ++i) { + const BufferAllocation& allocation = buffer_assignment.GetAllocation(i); + se::DeviceMemoryBase buffer_address = GetDeviceAddress(allocation.index()); + if (allocation.maybe_live_out() && !live_addresses.count(buffer_address)) { + // Deallocate buffers that marked "maybe_live_out" but is not actually + // live out. + TF_RETURN_IF_ERROR( + memory_allocator_->Deallocate(device_ordinal_, &buffer_address)); + } + } + + // Deallocate the memory block for temporary buffers. + if (temp_buffer_base_ != nullptr) { + TF_RETURN_IF_ERROR( + memory_allocator_->Deallocate(device_ordinal_, &temp_buffer_base_)); + } + return tensorflow::Status::OK(); +} + +se::DeviceMemoryBase BufferAllocations::GetDeviceAddress( + BufferAllocation::Index buffer_index) const { + CHECK_GE(buffer_index, 0); + CHECK_LT(buffer_index, buffers_.size()); + return buffers_[buffer_index]; +} + +void BufferAllocations::SetBuffer(BufferAllocation::Index buffer_index, + se::DeviceMemoryBase buffer) { + CHECK_GE(buffer_index, 0); + CHECK_LT(buffer_index, buffers_.size()); + buffers_[buffer_index] = buffer; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.h b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h new file mode 100644 index 0000000000..a0cd6cac01 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h @@ -0,0 +1,113 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_BUFFER_ALLOCATIONS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_BUFFER_ALLOCATIONS_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/gpu/temp_buffer_offsets.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { + +// A thread-compatible class that encapsulates the base addresses of the +// allocated device buffers. +class BufferAllocations { + public: + // This inner class encapsulates methods that build a BufferAllocations from + // the given buffer assignment. + class Builder { + public: + // Registers preallocated buffers (such as parameter addresses and + // user-specified result buffers) to the given buffer index. The builder + // will skip allocating buffers for registered buffer indices. + void RegisterBuffer(BufferAllocation::Index index, + perftools::gputools::DeviceMemoryBase address); + + // Builds a BufferAllocations object from the given buffer assignment. + // `memory_allocator` is what this function uses to allocate device memory. + // `device_ordinal` is the number of the device this function allocates + // memory on. + StatusOr> Build( + const BufferAssignment& buffer_assignment, + const TempBufferOffsets& temp_buffer_offsets, int device_ordinal, + DeviceMemoryAllocator* memory_allocator); + + private: + std::map + registered_buffers_; + }; + + BufferAllocations(const BufferAllocations&) = delete; + BufferAllocations& operator=(const BufferAllocations&) = delete; + + DeviceMemoryAllocator* memory_allocator() const { return memory_allocator_; } + int device_ordinal() const { return device_ordinal_; } + + // Returns the device address of buffer `buffer_index`. `buffer_index` must be + // a valid index, i.e., in [0, buffer_count). This function returns null if + // `buffer_index` is not assigned to a buffer address. + perftools::gputools::DeviceMemoryBase GetDeviceAddress( + BufferAllocation::Index buffer_index) const; + + perftools::gputools::DeviceMemoryBase GetTempBufferBase() const { + return temp_buffer_base_; + } + + // Tears down all buffers allocated by this object that are not in + // `live_addresses`. + tensorflow::Status TearDown( + const std::set& live_addresses, + const BufferAssignment& buffer_assignment); + + private: + BufferAllocations(BufferAllocation::Index buffer_count, + perftools::gputools::DeviceMemoryBase temp_buffer_base, + int device_ordinal, DeviceMemoryAllocator* memory_allocator) + : buffers_(buffer_count), + temp_buffer_base_( + perftools::gputools::DeviceMemory(temp_buffer_base)), + device_ordinal_(device_ordinal), + memory_allocator_(memory_allocator) {} + + // Sets the device address of buffer `buffer_index`. + void SetBuffer(BufferAllocation::Index buffer_index, + perftools::gputools::DeviceMemoryBase buffer); + + // An array of device pointers that stores the address of each buffer + // indexed by Index. Each element can point to a temporary buffer, an + // input buffer, or nullptr if no buffer is needed for that Index. + std::vector buffers_; + + // The base address of the memory block that contains all temporary buffers. 
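+  // Individual temporary buffers are addressed as this base plus the
+  // per-allocation offset from TempBufferOffsets (see Builder::Build).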
+ perftools::gputools::DeviceMemory temp_buffer_base_; + + int device_ordinal_; + + DeviceMemoryAllocator* memory_allocator_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_BUFFER_ALLOCATIONS_H_ diff --git a/tensorflow/compiler/xla/service/gpu/convolution_folding.cc b/tensorflow/compiler/xla/service/gpu/convolution_folding.cc new file mode 100644 index 0000000000..dd1b09c6cc --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/convolution_folding.cc @@ -0,0 +1,443 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/convolution_folding.h" + +#include +#include + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/window_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { +namespace gpu { + +namespace { +// Try to match a backward filter pattern that contains "conv". +// Precondition: "conv" is a kConvolution. +std::tuple, Window, + ConvolutionDimensionNumbers> +MatchBackwardFilter(HloInstruction* conv) { + const auto no_match_result = + std::make_tuple(false, std::vector(), Window(), + ConvolutionDimensionNumbers()); + // Step 1: match the instruction pattern without considering the paddings and + // dimension numbers just yet. We may need some generic pattern matcher + // similar to external/llvm/include/llvm/IR/PatternMatch.h + // + // Backward filter convolution is implemented in XLA as the forward + // convolution of padded activations and dilated gradients. Padding on + // activations and dilation on gradients are specified in the "window" field + // of the forward convolution. + // + // activations gradients + // \ / + // v v + // Convolution + // conv + // | + // v + // Transpose (optional if identity transposition) + CHECK_EQ(HloOpcode::kConvolution, conv->opcode()); + // If the forward convolution is followed by a transpose, we can fuse the + // transpose into the backward convolution as well. + HloInstruction* transpose = nullptr; + if (conv->user_count() == 1) { + HloInstruction* single_user = *conv->users().begin(); + if (single_user->opcode() == HloOpcode::kTranspose) { + transpose = single_user; + } + } + + // Step 2: match paddings and dimension numbers of the forward convolution. 
+ const ConvolutionDimensionNumbers& conv_dnums = + conv->convolution_dimension_numbers(); + auto batch_dim = conv_dnums.batch_dimension(); + auto feature_dim = conv_dnums.feature_dimension(); + auto spatial_dims = conv_dnums.spatial_dimensions(); + + for (const WindowDimension& window_dim : conv->window().dimensions()) { + if (window_dim.stride() != 1) { + VLOG(1) << "Forward convolution's window " + << conv->window().ShortDebugString() + << " should have stride of 1."; + return no_match_result; + } + if (window_dim.base_dilation() != 1) { + VLOG(1) << "Forward convolution's window " + << conv->window().ShortDebugString() + << " should have no base (LHS) dilation."; + return no_match_result; + } + if (window_dim.padding_low() < 0) { + VLOG(1) << "Padding low should be non-negative."; + return no_match_result; + } + // Padding high will be checked in Step 3. + } + if (transpose == nullptr && !window_util::HasWindowDilation(conv->window())) { + VLOG(1) << conv->ToString() + << " is a regular forward convolution. No need " + "to fold it to a backward filter convolution."; + return no_match_result; + } + + // Step 3: fuse the matched HLOs into a backward convolution instruction. + // + // Compute the window of the backward convolution. + Window backward_conv_window; + for (int i = 0; i < 2; ++i) { + WindowDimension* dim = backward_conv_window.add_dimensions(); + // The window size of the backward convolution equals the output size of the + // forward convolution. + int64 filter_size = conv->shape().dimensions(spatial_dims[i]); + dim->set_size(filter_size); + // The window stride equals the window dilation of the forward convolution. + dim->set_stride(conv->window().dimensions(i).window_dilation()); + // The window's low padding is the same as the low padding of the + // activations. + dim->set_padding_low(conv->window().dimensions(i).padding_low()); + + int64 input_size = conv->operand(0)->shape().dimensions(spatial_dims[i]); + int64 output_size = conv->window().dimensions(i).size(); + // Compute the range of the amount of valid high padding. We first compute + // min_padding_high, the amount of padding on the right/bottom to ensure the + // last patch ends at the border, i.e., + // + // input_size + dim->padding_low() + min_padding_high + // = (output_size - 1) * stride + filter_size + // + // Because convolution ignores trailing incomplete windows, any amount of + // padding high from min_padding_high to min_padding_high+stride-1 + // (max_padding_high) has the same effect. + int64 padded_input_size = filter_size + (output_size - 1) * dim->stride(); + int64 min_padding_high = + padded_input_size - input_size - dim->padding_low(); + int64 max_padding_high = min_padding_high + dim->stride() - 1; + CHECK_GE(dim->padding_low(), 0); + // In practice, since cuDNN convolution only supports even padding, we make + // the amount of high padding the same as the amount of low padding as long + // as it is between min_padding_high and max_padding_high. If it is not in + // that range, we pick the one that's closest to dim->padding_low() and let + // PadInsertion canonicalize the resultant backward convolution later. + // Picking the closest one minimizes the cost of the kPad instruction to be + // inserted by PadInsertion. 
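+    // Illustrative example (hypothetical sizes): with filter_size = 3,
+    // output_size = 4, stride = 2, input_size = 8 and padding_low = 1,
+    // padded_input_size = 3 + (4 - 1) * 2 = 9, so min_padding_high =
+    // 9 - 8 - 1 = 0 and max_padding_high = 0 + 2 - 1 = 1. Since
+    // padding_low = 1 falls inside [0, 1], we pick padding_high = 1 and the
+    // backward convolution keeps even padding.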
+    if (dim->padding_low() >= min_padding_high &&
+        dim->padding_low() <= max_padding_high) {
+      dim->set_padding_high(dim->padding_low());
+    } else {
+      if (dim->padding_low() < min_padding_high) {
+        dim->set_padding_high(min_padding_high);
+      } else {
+        dim->set_padding_high(max_padding_high);
+      }
+    }
+    if (dim->padding_high() < 0) {
+      LOG(ERROR)
+          << "Fusing this pattern to backward filter convolution would cause "
+             "negative padding ("
+          << dim->padding_high()
+          << ") on right/bottom of the weight gradients, which is not "
+             "supported by PadInsertion (b/32744257). Falling back to "
+             "unfused convolution for instruction: "
+          << conv->ToString();
+      return no_match_result;
+    }
+  }
+
+  // To make future HLO passes easier, we canonicalize the fused expression by
+  // adding an identity transposition if it's omitted in the pattern.
+  if (transpose == nullptr) {
+    // Create an identity transposition with the same rank as the forward
+    // convolution.
+    HloComputation* parent_computation = conv->parent();
+    std::vector<int64> transpose_dimensions(ShapeUtil::Rank(conv->shape()));
+    std::iota(transpose_dimensions.begin(), transpose_dimensions.end(), 0);
+    transpose =
+        parent_computation->AddInstruction(HloInstruction::CreateTranspose(
+            conv->shape(), conv, transpose_dimensions));
+    parent_computation->ReplaceUsesOfInstruction(conv, transpose);
+  }
+
+  // Restore the dimension numbers of the backward convolution from the forward
+  // convolution. The two activation dimensions are reversed (batch and
+  // feature).
+  ConvolutionDimensionNumbers backward_conv_dnums;
+  backward_conv_dnums.set_batch_dimension(feature_dim);
+  backward_conv_dnums.set_feature_dimension(batch_dim);
+  for (int i = 0; i < 2; ++i) {
+    backward_conv_dnums.add_spatial_dimensions(spatial_dims[i]);
+  }
+  // The dimension numbering of the output of the forward convolution (before
+  // transposition) is the same as that of the activations (according to the
+  // semantics of kConvolution). The batch dimension of the activations should
+  // be treated as the input feature dimension, and the feature dimension should
+  // be treated as the output feature dimension.
+  //
+  // The output of the forward convolution needs to be transposed to fit into
+  // the dimension numbering of the weight gradients. This transposition maps
+  // dimension i to PositionInContainer(transpose->dimensions(), i).
+  backward_conv_dnums.set_kernel_input_feature_dimension(
+      PositionInContainer(transpose->dimensions(), batch_dim));
+  backward_conv_dnums.set_kernel_output_feature_dimension(
+      PositionInContainer(transpose->dimensions(), feature_dim));
+  for (int i = 0; i < 2; ++i) {
+    backward_conv_dnums.add_kernel_spatial_dimensions(
+        PositionInContainer(transpose->dimensions(), spatial_dims[i]));
+  }
+
+  return std::make_tuple(true, std::vector<HloInstruction*>({transpose, conv}),
+                         backward_conv_window, backward_conv_dnums);
+}
+
+// Try to match a backward input pattern that contains "conv".
+// Precondition: "conv" is a kConvolution.
+std::tuple<bool, std::vector<HloInstruction*>, Window,
+           ConvolutionDimensionNumbers>
+MatchBackwardInput(HloInstruction* conv) {
+  const auto no_match_result =
+      std::make_tuple(false, std::vector<HloInstruction*>(), Window(),
+                      ConvolutionDimensionNumbers());
+
+  // Match instruction pattern.
+  CHECK_EQ(HloOpcode::kConvolution, conv->opcode());
+  HloInstruction* reverse_filter = conv->mutable_operand(1);
+
+  // Match the reverse of the filter.
+ ConvolutionDimensionNumbers dnums = conv->convolution_dimension_numbers(); + const auto& kernel_spatial_dims = dnums.kernel_spatial_dimensions(); + if (reverse_filter->opcode() == HloOpcode::kReverse) { + if (kernel_spatial_dims.size() != reverse_filter->dimensions().size() || + !std::is_permutation(kernel_spatial_dims.begin(), + kernel_spatial_dims.end(), + reverse_filter->dimensions().begin())) { + VLOG(1) + << "Backward input convolution should reverse all kernel dimensions."; + return no_match_result; + } + } else { + // Possibly 1x1 filter. + for (int64 i = 0; i < kernel_spatial_dims.size(); ++i) { + if (conv->window().dimensions(i).size() != 1) { + VLOG(1) << "The reverse filter is neither a kReverse nor a 1x1 filter: " + << reverse_filter->ToString(); + return no_match_result; + } + } + if (!window_util::HasBaseDilation(conv->window())) { + VLOG(1) << conv->ToString() + << " is a regular forward convolution. No need " + "to fold it to a backward input convolution."; + return no_match_result; + } + } + + // Match padding and dilation of the forward convolution. + for (const WindowDimension& window_dim : conv->window().dimensions()) { + if (window_dim.stride() != 1) { + VLOG(1) << "Forward convolution's window " + << conv->window().ShortDebugString() + << " should have stride of 1."; + return no_match_result; + } + if (window_dim.window_dilation() != 1) { + VLOG(1) << "Forward convolution's window " + << conv->window().ShortDebugString() + << " should have no window dilation."; + return no_match_result; + } + } + + const auto& spatial_dims = dnums.spatial_dimensions(); + CHECK_EQ(conv->window().dimensions().size(), spatial_dims.size()); + + const Window& old_window = conv->window(); + Window new_window = old_window; + for (size_t i = 0; i < spatial_dims.size(); ++i) { + // Restore backward convolution's padding config from the matched pattern. + // See the comment in tensorflow/core/kernels/conv_grad_ops.cc + // for how we convert backward input convolution to a variant of forward + // convolution. + // + // The stride of the backward convolution + // = the base dilation factor of the forward convolution + auto dim = new_window.mutable_dimensions(i); + dim->set_stride(old_window.dimensions(i).base_dilation()); + + // The low padding = kernel_size - 1 - low padding on the gradients + // Make sure the low padding is not negative. + auto kernel_size = old_window.dimensions(i).size(); + auto backward_padding_low = + kernel_size - 1 - old_window.dimensions(i).padding_low(); + if (backward_padding_low < 0) { + LOG(ERROR) + << "The low padding of the backward convolution would be negative (" + << backward_padding_low + << "), which isn't supported by PadInsertion for now (b/32744257)."; + return no_match_result; + } + dim->set_padding_low(backward_padding_low); + + // Compute the range of the amount of padding on the right/bottom of the + // activations. XLA's convolution requires all patches to be within the + // padded base. This gives us flexiblity to choose the amount of high + // padding from a set of values without changing the result of the backward + // convolution. The minimum amount (min_padding_high) makes the last patch + // end at the border. The maximum amount (max_padding_high) equals + // min_padding_high+stride-1 -- max_padding_high+1 would cause the output + // size to change. 
+ auto unpadded_input_size = conv->shape().dimensions(spatial_dims[i]); + auto output_size = conv->operand(0)->shape().dimensions(spatial_dims[i]); + auto padded_input_size = kernel_size + dim->stride() * (output_size - 1); + auto total_pad_size = padded_input_size - unpadded_input_size; + auto min_padding_high = total_pad_size - backward_padding_low; + auto max_padding_high = min_padding_high + dim->stride() - 1; + + if (backward_padding_low >= min_padding_high && + backward_padding_low <= max_padding_high) { + // In the best case (most likely), if backward_padding_low is in the range + // of the amounts of valid high padding, we choose backward_padding_low + // because cuDNN supports even padding only. + dim->set_padding_high(backward_padding_low); + } else { + // Otherwise, we choose the amount that's closest to backward_padding_low, + // and PadInsertion will later insert kSlice instructions to enforce even + // padding. + // + // For example, consider the backward convolution pattern + // + // ab xy + // | pad | reverse + // .a.b yx + // \ / + // ABC + // + // The amount of low padding on activations (in backward convolution) is + // backward_padding_low = kernel_size - 1 - forward_padding_low + // = 2 - 1 - 1 = 0 + // + // The amount of padding high must be between 1 and 2, in order to make + // Conv(ABC, xy, stride=2) produce exactly 2 elements (ab). 0 is not in + // the range of [1,2], so we pick the closest valid amount of padding + // high, which is 1 in this case. Therefore, we fuse the above pattern to + // + // ABC = BackwardInputConv(ab, xy, stride=2, padding_high=1) + if (backward_padding_low < min_padding_high) { + dim->set_padding_high(min_padding_high); + } else { + dim->set_padding_high(max_padding_high); + } + } + // PadInsertion doesn't handle backward input convolution with negative + // padding for now. So fall back to unfused convolution in case of negative + // padding. For example, + // ABCD = Conv(abc, reverse(xy), padding_high=2) + // could be fused to + // ABCD = BackwardInputConv(abc, xy, padding_low=1, padding_high=-1) + // with positive padding low but negative padding high. + if (dim->padding_high() < 0) { + LOG(ERROR) << "Fusing this pattern to backward convolution would cause " + "negative padding (" + << dim->padding_high() + << ") on right/bottom of the activations, which is not " + "supported by PadInsertion (b/32744257). Falling back to " + "unfused convolution for instruction: " + << conv->ToString(); + return no_match_result; + } + } + + // Fuse the matched HLOs into a backward convolution instruction. + // + // If the reverse is omitted (for 1x1 filters) in the original pattern, we add + // it back in the fusion instruction so that later passes (such as + // PadInsertion) can handle such fusion instructions easily. 
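+  // Note that reversing a 1x1 window leaves the data unchanged, so the
+  // kReverse inserted below only canonicalizes the HLO graph; it does not
+  // change the numerical result of the convolution.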
+  if (reverse_filter->opcode() != HloOpcode::kReverse) {
+    reverse_filter = reverse_filter->parent()->AddInstruction(
+        HloInstruction::CreateReverse(reverse_filter->shape(), reverse_filter,
+                                      AsInt64Slice(kernel_spatial_dims)));
+    conv->ReplaceOperandWith(/*operand_no=*/1, reverse_filter);
+  }
+  dnums.set_kernel_input_feature_dimension(
+      conv->convolution_dimension_numbers().kernel_output_feature_dimension());
+  dnums.set_kernel_output_feature_dimension(
+      conv->convolution_dimension_numbers().kernel_input_feature_dimension());
+
+  return std::make_tuple(true,
+                         std::vector<HloInstruction*>({conv, reverse_filter}),
+                         new_window, dnums);
+}
+}  // namespace
+
+StatusOr<bool> ConvolutionFolding::Run(HloModule* module) {
+  HloComputation* entry_computation = module->entry_computation();
+  std::vector<HloInstruction*> convs;
+  for (const auto& hlo : entry_computation->instructions()) {
+    if (hlo->opcode() == HloOpcode::kConvolution) {
+      convs.push_back(hlo.get());
+    }
+  }
+
+  bool changed = false;
+  for (HloInstruction* conv : convs) {
+    bool match;
+    std::vector<HloInstruction*> hlos_to_fuse;
+    Window window;
+    ConvolutionDimensionNumbers dnums;
+    std::tie(match, hlos_to_fuse, window, dnums) = MatchBackwardFilter(conv);
+    if (match) {
+      VLOG(2) << "Fuse instructions";
+      for (HloInstruction* hlo_to_fuse : hlos_to_fuse) {
+        VLOG(2) << "  " << hlo_to_fuse->ToString();
+      }
+      HloInstruction* backward_convolution =
+          entry_computation->CreateFusionInstructionForBackwardConvolution(
+              hlos_to_fuse, HloInstruction::FusionKind::kConvBackwardFilter,
+              window, dnums);
+      VLOG(2) << "to backward filter convolution";
+      VLOG(2) << "  " << backward_convolution->ToString();
+      changed = true;
+      continue;
+    }
+
+    std::tie(match, hlos_to_fuse, window, dnums) = MatchBackwardInput(conv);
+    if (match) {
+      VLOG(2) << "Fuse instructions";
+      for (HloInstruction* hlo_to_fuse : hlos_to_fuse) {
+        VLOG(2) << "  " << hlo_to_fuse->ToString();
+      }
+      HloInstruction* backward_convolution =
+          entry_computation->CreateFusionInstructionForBackwardConvolution(
+              hlos_to_fuse, HloInstruction::FusionKind::kConvBackwardInput,
+              window, dnums);
+      VLOG(2) << "to backward input convolution";
+      VLOG(2) << "  " << backward_convolution->ToString();
+      changed = true;
+      continue;
+    }
+  }
+  return changed;
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_folding.h b/tensorflow/compiler/xla/service/gpu/convolution_folding.h
new file mode 100644
index 0000000000..e0233228c7
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/convolution_folding.h
@@ -0,0 +1,34 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONVOLUTION_FOLDING_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONVOLUTION_FOLDING_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass.h" + +namespace xla { +namespace gpu { + +class ConvolutionFolding : public HloPass { + public: + ConvolutionFolding() : HloPass("convolution-folding") {} + StatusOr Run(HloModule* module) override; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONVOLUTION_FOLDING_H_ diff --git a/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc b/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc new file mode 100644 index 0000000000..83922cbe14 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc @@ -0,0 +1,552 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/convolution_folding.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace gpu { + +class ConvolutionFoldingTest : public HloTestBase { + public: + ConvolutionFoldingTest() { + for (int i = 0; i < 2; ++i) { + WindowDimension* window_dim = default_conv_window_.add_dimensions(); + window_dim->set_size(1); + window_dim->set_stride(1); + window_dim->set_padding_low(0); + window_dim->set_padding_high(0); + window_dim->set_window_dilation(1); + window_dim->set_base_dilation(1); + } + // TF data shapes are by default in the NHWC order, and filter shape is by + // default in HWIO order. For backward filter convolution, we need to swap + // the batch and feature dimension in the activations, and treat the batch + // dimension in gradients as the input feature dimension in the filter. + // + // TODO(jingyue): Add more tests on NCHW input order which TF also supports. 
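+    // With NHWC activations indexed as (N=0, H=1, W=2, C=3), the dimension
+    // numbers below therefore read: batch = 3 (C plays the batch role after
+    // the swap), feature = 0 (N plays the feature role), and spatial = {1, 2}.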
+ tf_default_dnums_for_backward_filter_.set_batch_dimension(3); + tf_default_dnums_for_backward_filter_.set_feature_dimension(0); + tf_default_dnums_for_backward_filter_.add_spatial_dimensions(1); + tf_default_dnums_for_backward_filter_.add_spatial_dimensions(2); + tf_default_dnums_for_backward_filter_.set_kernel_input_feature_dimension(0); + tf_default_dnums_for_backward_filter_.set_kernel_output_feature_dimension( + 3); + tf_default_dnums_for_backward_filter_.add_kernel_spatial_dimensions(1); + tf_default_dnums_for_backward_filter_.add_kernel_spatial_dimensions(2); + + tf_default_dnums_for_backward_input_.set_batch_dimension(0); + tf_default_dnums_for_backward_input_.set_feature_dimension(3); + tf_default_dnums_for_backward_input_.add_spatial_dimensions(1); + tf_default_dnums_for_backward_input_.add_spatial_dimensions(2); + tf_default_dnums_for_backward_input_.set_kernel_input_feature_dimension(3); + tf_default_dnums_for_backward_input_.set_kernel_output_feature_dimension(2); + tf_default_dnums_for_backward_input_.add_kernel_spatial_dimensions(0); + tf_default_dnums_for_backward_input_.add_kernel_spatial_dimensions(1); + } + + protected: + bool FoldConvolution(HloModule* module) { + ConvolutionFolding convolution_folding; + return convolution_folding.Run(module).ValueOrDie(); + } + + // A convolution window with stride 1 and zero padding. The size fields are + // not set. + Window default_conv_window_; + ConvolutionDimensionNumbers tf_default_dnums_for_backward_filter_; + ConvolutionDimensionNumbers tf_default_dnums_for_backward_input_; +}; + +TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithoutTranspose) { + HloComputation::Builder builder(TestName()); + HloInstruction* activations = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "activations")); + HloInstruction* gradients = + builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {1, 1, 2, 1}), "gradients")); + Window conv_window = default_conv_window_; + conv_window.mutable_dimensions(1)->set_size(2); + conv_window.mutable_dimensions(1)->set_window_dilation(2); + builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeInference::InferConvolveShape(activations->shape(), + gradients->shape(), conv_window, + tf_default_dnums_for_backward_filter_) + .ConsumeValueOrDie(), + activations, gradients, conv_window, + tf_default_dnums_for_backward_filter_)); + + HloModule module(TestName()); + HloComputation* entry_computation = + module.AddEntryComputation(builder.Build()); + EXPECT_TRUE(FoldConvolution(&module)); + EXPECT_EQ(HloOpcode::kFusion, + entry_computation->root_instruction()->opcode()); + EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardFilter == + entry_computation->root_instruction()->fusion_kind()); +} + +TEST_F(ConvolutionFoldingTest, + BackwardFilterConvolveEquivalentToForwardConvolution) { + HloComputation::Builder builder(TestName()); + HloInstruction* activations = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "activations")); + HloInstruction* gradients = + builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "gradients")); + Window conv_window = default_conv_window_; + conv_window.mutable_dimensions(1)->set_size(3); + builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeInference::InferConvolveShape(activations->shape(), + gradients->shape(), conv_window, + tf_default_dnums_for_backward_filter_) + 
.ConsumeValueOrDie(), + activations, gradients, conv_window, + tf_default_dnums_for_backward_filter_)); + + HloModule module(TestName()); + module.AddEntryComputation(builder.Build()); + EXPECT_FALSE(FoldConvolution(&module)); +} + +// Extracted from block35 training. +TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithPaddedActivations) { + auto builder = HloComputation::Builder(TestName()); + HloInstruction* activations = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {20, 35, 35, 32}), "activations")); + HloInstruction* gradients = + builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {20, 35, 35, 32}), "gradients")); + + Window conv_window = default_conv_window_; + for (int i = 0; i < 2; ++i) { + conv_window.mutable_dimensions(i)->set_size(35); + conv_window.mutable_dimensions(i)->set_padding_low(1); + conv_window.mutable_dimensions(i)->set_padding_high(1); + } + HloInstruction* convolution = + builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeUtil::MakeShape(F32, {32, 3, 3, 32}), activations, gradients, + conv_window, tf_default_dnums_for_backward_filter_)); + + builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {3, 3, 32, 32}), convolution, {1, 2, 3, 0})); + + HloModule module(TestName()); + HloComputation* entry_computation = + module.AddEntryComputation(builder.Build()); + EXPECT_TRUE(FoldConvolution(&module)); + EXPECT_EQ(HloOpcode::kFusion, + entry_computation->root_instruction()->opcode()); + EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardFilter == + entry_computation->root_instruction()->fusion_kind()); +} + +// Extracted from inception v3 training. +TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithPaddedGradients) { + auto builder = HloComputation::Builder(TestName()); + HloInstruction* activations = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), "activations")); + HloInstruction* gradients = + builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {20, 4, 4, 320}), "gradients")); + + Window conv_window = default_conv_window_; + for (int i = 0; i < 2; ++i) { + conv_window.mutable_dimensions(i)->set_size(4); + conv_window.mutable_dimensions(i)->set_padding_high(-1); + conv_window.mutable_dimensions(i)->set_window_dilation(2); + } + HloInstruction* convolution = + builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeUtil::MakeShape(F32, {320, 3, 3, 192}), activations, gradients, + conv_window, tf_default_dnums_for_backward_filter_)); + + builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {3, 3, 192, 320}), convolution, {1, 2, 3, 0})); + + HloModule module(TestName()); + HloComputation* entry_computation = + module.AddEntryComputation(builder.Build()); + EXPECT_TRUE(FoldConvolution(&module)); + EXPECT_EQ(HloOpcode::kFusion, + entry_computation->root_instruction()->opcode()); + EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardFilter == + entry_computation->root_instruction()->fusion_kind()); +} + +TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithUnevenPadding) { + auto builder = HloComputation::Builder(TestName()); + HloInstruction* activations = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {20, 35, 35, 32}), "activations")); + HloInstruction* gradients = + builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {20, 35, 
35, 32}), "gradients")); + + Window conv_window = default_conv_window_; + for (int i = 0; i < 2; ++i) { + conv_window.mutable_dimensions(i)->set_size(35); + // Uneven padding: padding_low=0, padding_high=1 + conv_window.mutable_dimensions(i)->set_padding_high(1); + } + HloInstruction* convolution = + builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeUtil::MakeShape(F32, {32, 2, 2, 32}), activations, gradients, + conv_window, tf_default_dnums_for_backward_filter_)); + + builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {2, 2, 32, 32}), convolution, {1, 2, 3, 0})); + + HloModule module(TestName()); + HloComputation* entry_computation = + module.AddEntryComputation(builder.Build()); + EXPECT_TRUE(FoldConvolution(&module)); + EXPECT_EQ(HloOpcode::kFusion, + entry_computation->root_instruction()->opcode()); + EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardFilter == + entry_computation->root_instruction()->fusion_kind()); +} + +TEST_F(ConvolutionFoldingTest, BackwardInputConvolveEvenPadding) { + auto builder = HloComputation::Builder(TestName()); + HloInstruction* output = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {4, 5, 16, 16}), "output")); + HloInstruction* kernel = + builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {5, 3, 7, 7}), "kernel")); + HloInstruction* reverse_kernel = builder.AddInstruction( + HloInstruction::CreateReverse(kernel->shape(), kernel, {2, 3})); + + Window conv_window = default_conv_window_; + for (int i = 0; i < 2; ++i) { + conv_window.mutable_dimensions(i)->set_size(7); + conv_window.mutable_dimensions(i)->set_padding_low(3); + conv_window.mutable_dimensions(i)->set_padding_high(3); + } + ConvolutionDimensionNumbers conv_dnums; + conv_dnums.set_batch_dimension(0); + conv_dnums.set_feature_dimension(1); + conv_dnums.add_spatial_dimensions(2); + conv_dnums.add_spatial_dimensions(3); + conv_dnums.set_kernel_input_feature_dimension(0); + conv_dnums.set_kernel_output_feature_dimension(1); + conv_dnums.add_kernel_spatial_dimensions(2); + conv_dnums.add_kernel_spatial_dimensions(3); + + HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeUtil::MakeShape(F32, {4, 3, 16, 16}), /*lhs=*/output, + /*rhs=*/reverse_kernel, conv_window, conv_dnums)); + // Verify the convolution's shape is consistent with ShapeInference. + CHECK(ShapeUtil::Compatible( + conv->shape(), + ShapeInference::InferConvolveShape( + output->shape(), reverse_kernel->shape(), conv_window, conv_dnums) + .ValueOrDie())); + + HloModule module(TestName()); + HloComputation* entry_computation = + module.AddEntryComputation(builder.Build()); + EXPECT_TRUE(FoldConvolution(&module)); + EXPECT_EQ(HloOpcode::kFusion, + entry_computation->root_instruction()->opcode()); + EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardInput == + entry_computation->root_instruction()->fusion_kind()); + for (int i = 0; i < 2; ++i) { + const WindowDimension& window_dim = + entry_computation->root_instruction()->window().dimensions(i); + // Low padding of the backward input convolution + // = kernel_size - 1 - low padding on gradients. 
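+    // Here kernel_size = 7 and the forward low padding is 3, so the expected
+    // value is 7 - 1 - 3 = 3.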
+ EXPECT_EQ(3, window_dim.padding_low()); + EXPECT_EQ(3, window_dim.padding_high()); + EXPECT_EQ(1, window_dim.stride()); + } +} + +// Convolve([abc], [x], base_dilation=2) +// = Convolve([abc], Reverse([x]), base_dilation=2) +// = BackwardInputConvolve([abc], [x], stride=2) +TEST_F(ConvolutionFoldingTest, BackwardInputConvolve1x1Filter) { + auto builder = HloComputation::Builder(TestName()); + // NHWC dimension order. + HloInstruction* output = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "output")); + // HWOI dimension order. + HloInstruction* kernel = + builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {1, 1, 1, 1}), "kernel")); + + Window conv_window = default_conv_window_; + conv_window.mutable_dimensions(1)->set_base_dilation(2); + + builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeInference::InferConvolveShape(output->shape(), kernel->shape(), + conv_window, + tf_default_dnums_for_backward_input_) + .ConsumeValueOrDie(), + /*lhs=*/output, /*rhs=*/kernel, conv_window, + tf_default_dnums_for_backward_input_)); + + HloModule module(TestName()); + HloComputation* entry_computation = + module.AddEntryComputation(builder.Build()); + EXPECT_TRUE(FoldConvolution(&module)); + EXPECT_EQ(HloOpcode::kFusion, + entry_computation->root_instruction()->opcode()); + EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardInput == + entry_computation->root_instruction()->fusion_kind()); +} + +// BackwardInputConvolve([abc], [x], stride=1) is equivalent to +// ForwardConvolve([abc], [x], stride=1). No need to fold it into backward input +// convolution. +TEST_F(ConvolutionFoldingTest, + BackwardInputConvolve1x1FilterEquivalentToForwardConvolve) { + auto builder = HloComputation::Builder(TestName()); + // NHWC dimension order. + HloInstruction* output = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "output")); + // HWOI dimension order. + HloInstruction* kernel = + builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {1, 1, 1, 1}), "kernel")); + + builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeInference::InferConvolveShape(output->shape(), kernel->shape(), + default_conv_window_, + tf_default_dnums_for_backward_input_) + .ConsumeValueOrDie(), + /*lhs=*/output, /*rhs=*/kernel, default_conv_window_, + tf_default_dnums_for_backward_input_)); + + HloModule module(TestName()); + module.AddEntryComputation(builder.Build()); + EXPECT_FALSE(FoldConvolution(&module)); +} + +// Extracted from Inception V3 training. +// +// filter(HWIO) +// 3x3x192x320 +// | +// v +// gradients(NHWC) reverse +// 20x4x4x320 3x3x192x320 +// \ / +// \ / +// conv (NHWC) with padding (low=2,high=3,interior=1) +// 20x10x10x192 +// +// Gradients are padded unevenly. 
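+//
+// After folding, the backward input convolution should use stride 2 (the
+// forward base dilation), padding_low = 3 - 1 - 2 = 0, and an even
+// padding_high of 0, which is what the expectations below check.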
+TEST_F(ConvolutionFoldingTest, BackwardInputConvolveUnevenPaddingOnGradients) { + auto builder = HloComputation::Builder(TestName()); + HloInstruction* output = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {20, 4, 4, 320}), "output")); + HloInstruction* kernel = + builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {3, 3, 192, 320}), "kernel")); + HloInstruction* reverse_kernel = builder.AddInstruction( + HloInstruction::CreateReverse(kernel->shape(), kernel, {0, 1})); + + Window conv_window = default_conv_window_; + for (int i = 0; i < 2; ++i) { + conv_window.mutable_dimensions(i)->set_size(3); + conv_window.mutable_dimensions(i)->set_padding_low(2); + conv_window.mutable_dimensions(i)->set_padding_high(3); + // Interior padding = 1. + conv_window.mutable_dimensions(i)->set_base_dilation(2); + } + HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), output, reverse_kernel, + conv_window, tf_default_dnums_for_backward_input_)); + // Verify the convolution's shape is consistent with ShapeInference. + CHECK(ShapeUtil::Compatible( + conv->shape(), ShapeInference::InferConvolveShape( + output->shape(), reverse_kernel->shape(), conv_window, + tf_default_dnums_for_backward_input_) + .ValueOrDie())); + + HloModule module(TestName()); + HloComputation* entry_computation = + module.AddEntryComputation(builder.Build()); + EXPECT_TRUE(FoldConvolution(&module)); + EXPECT_EQ(HloOpcode::kFusion, + entry_computation->root_instruction()->opcode()); + EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardInput == + entry_computation->root_instruction()->fusion_kind()); + for (int i = 0; i < 2; ++i) { + const WindowDimension& window_dim = + entry_computation->root_instruction()->window().dimensions(i); + EXPECT_EQ(0, window_dim.padding_low()); + EXPECT_EQ(0, window_dim.padding_high()); + EXPECT_EQ(2, window_dim.stride()); + } +} + +// Similar to BackwardInputConvolveUnevenPadding, but the low padding of the +// gradients exceeds kernel_size - 1. Therefore, this pattern cannot be fused. +TEST_F(ConvolutionFoldingTest, BackwardInputConvolveLowPaddingTooLarge) { + auto builder = HloComputation::Builder(TestName()); + HloInstruction* output = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {20, 4, 4, 320}), "output")); + HloInstruction* kernel = + builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {3, 3, 192, 320}), "kernel")); + HloInstruction* reverse_kernel = builder.AddInstruction( + HloInstruction::CreateReverse(kernel->shape(), kernel, {0, 1})); + + Window conv_window = default_conv_window_; + for (int i = 0; i < 2; ++i) { + conv_window.mutable_dimensions(i)->set_size(3); + conv_window.mutable_dimensions(i)->set_padding_low(3); + conv_window.mutable_dimensions(i)->set_padding_high(2); + conv_window.mutable_dimensions(i)->set_base_dilation(2); + } + HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), output, reverse_kernel, + conv_window, tf_default_dnums_for_backward_input_)); + // Verify the convolution's shape is consistent with ShapeInference. 
+ CHECK(ShapeUtil::Compatible( + conv->shape(), ShapeInference::InferConvolveShape( + output->shape(), reverse_kernel->shape(), conv_window, + tf_default_dnums_for_backward_input_) + .ValueOrDie())); + + HloModule module(TestName()); + module.AddEntryComputation(builder.Build()); + EXPECT_FALSE(FoldConvolution(&module)); +} + +// Extracted from //learning/brain/google/xla/benchmarks/resnet.py +// +// For simplicity, we focus on the column dimension and ignore other dimensions. +// We use [?] to represent the shape instead of the content. +// +// Suppose operator FC does +// [4] = conv([14], [3], stride=2, padding_high=1) // Padding::kSame +// +// BC = BackwardInput(FC) does: +// [14] = conv([7], reverse([3]), +// padding_low=2, padding_high=1, base_dilation=2) +// +// We should fuse BC even though padding on activations is uneven, because +// PadInsertion will canonicalize the fusion HLO. +TEST_F(ConvolutionFoldingTest, + BackwardInputConvolveUnevenPaddingOnActivations) { + auto builder = HloComputation::Builder(TestName()); + // The gradients are in NCHW layout. + HloInstruction* output = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1, 1, 7, 1}), "output")); + // The kernel is in HWIO layout. + HloInstruction* kernel = + builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {1, 3, 1, 1}), "kernel")); + HloInstruction* reverse_kernel = builder.AddInstruction( + HloInstruction::CreateReverse(kernel->shape(), kernel, {0, 1})); + + Window conv_window = default_conv_window_; + WindowDimension* forward_conv_col_dim = conv_window.mutable_dimensions(1); + forward_conv_col_dim->set_size(3); + forward_conv_col_dim->set_padding_low(2); + forward_conv_col_dim->set_padding_high(1); + forward_conv_col_dim->set_base_dilation(2); + HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeUtil::MakeShape(F32, {1, 1, 14, 1}), output, reverse_kernel, + conv_window, tf_default_dnums_for_backward_input_)); + // Verify the convolution's shape is consistent with ShapeInference. + CHECK(ShapeUtil::Compatible( + conv->shape(), ShapeInference::InferConvolveShape( + output->shape(), reverse_kernel->shape(), conv_window, + tf_default_dnums_for_backward_input_) + .ValueOrDie())); + + HloModule module(TestName()); + const HloComputation* entry_computation = + module.AddEntryComputation(builder.Build()); + EXPECT_TRUE(FoldConvolution(&module)); + const HloInstruction* backward_conv = entry_computation->root_instruction(); + EXPECT_EQ(HloOpcode::kFusion, backward_conv->opcode()); + EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardInput == + backward_conv->fusion_kind()); + const WindowDimension& backward_conv_col_dim = + backward_conv->window().dimensions(1); + EXPECT_EQ(0, backward_conv_col_dim.padding_low()); + EXPECT_EQ(1, backward_conv_col_dim.padding_high()); +} + +// For simplicity, we focus on the column dimension and ignore other dimensions. +// We use [?] to represent the shape instead of the content. +// +// Suppose operator FC does +// [3] = conv([4], [2], padding_low=1, padding_high=-1) +// +// BC = BackwardInput(FC) does: +// [4] = conv([3], reverse([2]), padding_high=2) +// +// We currently don't fuse BC because PadInsertion doesn't support negative +// padding on the gradients of backward convolution (b/32744257). +TEST_F(ConvolutionFoldingTest, + BackwardInputConvolveNegativePaddingHighOnActivations) { + auto builder = HloComputation::Builder(TestName()); + // The gradients are in NCHW layout. 
+ HloInstruction* output = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "output")); + // The kernel is in HWIO layout. + HloInstruction* kernel = + builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {1, 2, 1, 1}), "kernel")); + HloInstruction* reverse_kernel = builder.AddInstruction( + HloInstruction::CreateReverse(kernel->shape(), kernel, {0, 1})); + + Window conv_window = default_conv_window_; + WindowDimension* forward_conv_col_dim = conv_window.mutable_dimensions(1); + forward_conv_col_dim->set_size(2); + forward_conv_col_dim->set_padding_high(2); + HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeUtil::MakeShape(F32, {1, 1, 4, 1}), output, reverse_kernel, + conv_window, tf_default_dnums_for_backward_input_)); + // Verify the convolution's shape is consistent with ShapeInference. + CHECK(ShapeUtil::Compatible( + conv->shape(), ShapeInference::InferConvolveShape( + output->shape(), reverse_kernel->shape(), conv_window, + tf_default_dnums_for_backward_input_) + .ValueOrDie())); + + HloModule module(TestName()); + module.AddEntryComputation(builder.Build()); + EXPECT_FALSE(FoldConvolution(&module)); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc new file mode 100644 index 0000000000..30a92ab313 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc @@ -0,0 +1,324 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" + +#include + +#include "tensorflow/compiler/xla/legacy_flags/convolution_thunk_flags.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace se = ::perftools::gputools; + +namespace xla { +namespace gpu { + +using Index = BufferAllocation::Index; +using se::dnn::BatchDescriptor; +using se::dnn::ConvolutionDescriptor; +using se::dnn::DataLayout; +using se::dnn::FilterDescriptor; +using se::dnn::FilterLayout; + +ConvolveScratchAllocator::ConvolveScratchAllocator( + int device_ordinal, DeviceMemoryAllocator* memory_allocator) + : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {} + +ConvolveScratchAllocator::~ConvolveScratchAllocator() { + for (auto& allocated_buffer : allocated_buffers_) { + if (!memory_allocator_->Deallocate(device_ordinal_, &allocated_buffer) + .ok()) { + // The program can still continue with failed deallocation. 
+ LOG(ERROR) << "Failed to deallocate the allocated buffer: " + << allocated_buffer.opaque(); + } + } +} + +int64 ConvolveScratchAllocator::GetMemoryLimitInBytes(se::Stream* stream) { + constexpr int64 kConvolveScratchSize = 1LL << 32; // 4GB by default. + return kConvolveScratchSize; +} + +se::port::StatusOr> +ConvolveScratchAllocator::AllocateBytes(se::Stream* stream, int64 byte_size) { + CHECK_GE(byte_size, 0) << "byte_size must be positive."; + if (byte_size > GetMemoryLimitInBytes(stream)) { + return se::port::Status( + se::port::error::RESOURCE_EXHAUSTED, + tensorflow::strings::Printf( + "Allocating %lld bytes exceeds the memory limit of %lld bytes.", + byte_size, GetMemoryLimitInBytes(stream))); + } + + auto status_or_memory = + memory_allocator_->Allocate(device_ordinal_, byte_size, + /*retry_on_failure=*/false); + if (!status_or_memory.ok()) { + return se::port::Status(se::port::error::RESOURCE_EXHAUSTED, + tensorflow::strings::Printf( + "Failed to allocate %lld bytes on device %d.", + byte_size, device_ordinal_)); + } + se::DeviceMemoryBase allocated_buffer = status_or_memory.ValueOrDie(); + allocated_buffers_.push_back(allocated_buffer); + total_allocated_bytes_ += byte_size; + return se::DeviceMemory(allocated_buffer); +} + +string ConvolutionKindToString( + ConvolutionThunk::ConvolutionKind convolution_kind) { + switch (convolution_kind) { + case ConvolutionThunk::ConvolutionKind::kForward: + return "forward"; + case ConvolutionThunk::ConvolutionKind::kBackwardFilter: + return "backward_filter"; + case ConvolutionThunk::ConvolutionKind::kBackwardInput: + return "backward_input"; + } +} + +ConvolutionThunk::ConvolutionThunk( + ConvolutionKind convolution_kind, Index input_buffer, Index filter_buffer, + Index output_buffer, const Shape& input_shape, const Shape& filter_shape, + const Shape& output_shape, const Window& window, + const ConvolutionDimensionNumbers& dim_nums, const HloInstruction* hlo) + : Thunk(Kind::kConvolution, hlo), + convolution_kind_(convolution_kind), + input_buffer_(input_buffer), + filter_buffer_(filter_buffer), + output_buffer_(output_buffer), + input_shape_(input_shape), + filter_shape_(filter_shape), + output_shape_(output_shape), + window_(window), + dim_nums_(dim_nums) {} + +tensorflow::Status ConvolutionThunk::ExecuteOnStream( + const BufferAllocations& buffer_allocations, se::Stream* stream) { + VLOG(3) << "Convolution kind: " << ConvolutionKindToString(convolution_kind_); + VLOG(3) << "input shape: { " << input_shape_.ShortDebugString() << " }"; + VLOG(3) << "filter shape: { " << filter_shape_.ShortDebugString() << " }"; + VLOG(3) << "Output shape: { " << output_shape_.ShortDebugString() << " }"; + VLOG(3) << "Dim nums: { " << dim_nums_.ShortDebugString() << " }"; + VLOG(3) << "Window: { " << window_.ShortDebugString() << " }"; + + CHECK_EQ(F32, output_shape_.element_type()); + CHECK_EQ(2, window_.dimensions_size()); + for (const WindowDimension& dim : window_.dimensions()) { + CHECK_EQ(dim.padding_low(), dim.padding_high()); + } + + const WindowDimension& height = window_.dimensions(0); + const WindowDimension& width = window_.dimensions(1); + // cuDNN's convolution APIs support the BDYX layout for activations/output and + // the OIYX layout for weights. + // TODO(b/29399649): Be more flexible about handling layouts of cuDNN calls + // when we switch to cuDNN v5. 
+ BatchDescriptor input_descriptor; + input_descriptor.set_layout(DataLayout::kBatchDepthYX) + .set_height(input_shape_.dimensions(dim_nums_.spatial_dimensions(0))) + .set_width(input_shape_.dimensions(dim_nums_.spatial_dimensions(1))) + .set_feature_map_count( + input_shape_.dimensions(dim_nums_.feature_dimension())) + .set_count(input_shape_.dimensions(dim_nums_.batch_dimension())); + + FilterDescriptor filter_descriptor; + filter_descriptor.set_layout(FilterLayout::kOutputInputYX) + .set_input_feature_map_count( + filter_shape_.dimensions(dim_nums_.kernel_input_feature_dimension())) + .set_output_feature_map_count( + filter_shape_.dimensions(dim_nums_.kernel_output_feature_dimension())) + .set_input_filter_height( + filter_shape_.dimensions(dim_nums_.kernel_spatial_dimensions(0))) + .set_input_filter_width( + filter_shape_.dimensions(dim_nums_.kernel_spatial_dimensions(1))); + + ConvolutionDescriptor convolution_descriptor; + convolution_descriptor.set_zero_padding_width(width.padding_low()) + .set_zero_padding_height(height.padding_low()) + .set_horizontal_filter_stride(width.stride()) + .set_vertical_filter_stride(height.stride()); + + BatchDescriptor output_descriptor; + output_descriptor.set_layout(DataLayout::kBatchDepthYX) + .set_height(output_shape_.dimensions(dim_nums_.spatial_dimensions(0))) + .set_width(output_shape_.dimensions(dim_nums_.spatial_dimensions(1))) + .set_feature_map_count( + output_shape_.dimensions(dim_nums_.feature_dimension())) + .set_count(output_shape_.dimensions(dim_nums_.batch_dimension())); + + se::DeviceMemory input_data( + buffer_allocations.GetDeviceAddress(input_buffer_)); + se::DeviceMemory filter_data( + buffer_allocations.GetDeviceAddress(filter_buffer_)); + se::DeviceMemory output_data( + buffer_allocations.GetDeviceAddress(output_buffer_)); + return ConvolveWithTune(input_descriptor, input_data, filter_descriptor, + filter_data, output_descriptor, output_data, + convolution_descriptor, buffer_allocations, stream); +} + +tensorflow::Status ConvolutionThunk::Convolve( + const BatchDescriptor& input_descriptor, se::DeviceMemory input_data, + const FilterDescriptor& filter_descriptor, + se::DeviceMemory filter_data, + const BatchDescriptor& output_descriptor, + se::DeviceMemory output_data, + const ConvolutionDescriptor& convolution_descriptor, + const se::dnn::AlgorithmConfig& algorithm_config, se::Stream* stream, + ConvolveScratchAllocator* scratch_allocator, + se::dnn::ProfileResult* profile_result) { + bool launch_ok; + switch (convolution_kind_) { + case ConvolutionKind::kBackwardFilter: + launch_ok = + stream + ->ThenConvolveBackwardFilterWithAlgorithm( + input_descriptor, input_data, output_descriptor, output_data, + convolution_descriptor, filter_descriptor, &filter_data, + scratch_allocator, algorithm_config, profile_result) + .ok(); + break; + case ConvolutionKind::kBackwardInput: + launch_ok = stream + ->ThenConvolveBackwardDataWithAlgorithm( + filter_descriptor, filter_data, output_descriptor, + output_data, convolution_descriptor, input_descriptor, + &input_data, scratch_allocator, algorithm_config, + profile_result) + .ok(); + break; + case ConvolutionKind::kForward: + launch_ok = + stream + ->ThenConvolveWithAlgorithm( + input_descriptor, input_data, filter_descriptor, filter_data, + convolution_descriptor, output_descriptor, &output_data, + scratch_allocator, algorithm_config, profile_result) + .ok(); + break; + } + if (launch_ok) { + return tensorflow::Status::OK(); + } + return InternalError( + "Unable to launch convolution for 
thunk %p with type %s and algorithm " + "(%lld, %lld)", + this, ConvolutionKindToString(convolution_kind_).c_str(), + algorithm_config.algorithm(), algorithm_config.algorithm_no_scratch()); +} + +std::vector ConvolutionThunk::GetAlgorithms( + se::StreamExecutor* stream_exec) const { + std::vector algorithms; + switch (convolution_kind_) { + case ConvolutionKind::kBackwardFilter: + CHECK(stream_exec->GetConvolveBackwardFilterAlgorithms(&algorithms)); + break; + case ConvolutionKind::kBackwardInput: + CHECK(stream_exec->GetConvolveBackwardDataAlgorithms(&algorithms)); + break; + case ConvolutionKind::kForward: + CHECK(stream_exec->GetConvolveAlgorithms(&algorithms)); + break; + } + return algorithms; +} + +tensorflow::Status ConvolutionThunk::ConvolveWithTune( + const BatchDescriptor& input_descriptor, se::DeviceMemory input_data, + const FilterDescriptor& filter_descriptor, + se::DeviceMemory filter_data, + const BatchDescriptor& output_descriptor, + se::DeviceMemory output_data, + const ConvolutionDescriptor& convolution_descriptor, + const BufferAllocations& buffer_allocations, se::Stream* stream) { + // TODO(b/29126320): Try cudnn v5's new auto-tuner when it's rolled out. + legacy_flags::ConvolutionThunkFlags* flags = + legacy_flags::GetConvolutionThunkFlags(); + if (flags->xla_gpu_autotune_convolution_algorithm && + best_algorithm_.algorithm() == se::dnn::kDefaultAlgorithm) { + // Auto-tuning either is disabled or only happens in the first run of this + // function. + VLOG(2) << "Profiling for best convolution algorithm used for " + "ConvolutionThunk: " + << this; + + se::dnn::ProfileResult best_result; + se::dnn::ProfileResult best_result_without_scratch; + for (se::dnn::AlgorithmType algorithm : GetAlgorithms(stream->parent())) { + ConvolveScratchAllocator scratch_allocator( + buffer_allocations.device_ordinal(), + buffer_allocations.memory_allocator()); + se::dnn::ProfileResult profile_result; + bool launch_ok = + Convolve(input_descriptor, input_data, filter_descriptor, filter_data, + output_descriptor, output_data, convolution_descriptor, + se::dnn::AlgorithmConfig(algorithm, algorithm), stream, + &scratch_allocator, &profile_result) + .ok(); + if (launch_ok && profile_result.is_valid()) { + if (profile_result.elapsed_time_in_ms() < + best_result.elapsed_time_in_ms()) { + best_result = profile_result; + } + if (scratch_allocator.TotalAllocatedBytes() == 0 && + profile_result.elapsed_time_in_ms() < + best_result_without_scratch.elapsed_time_in_ms()) { + best_result_without_scratch = profile_result; + } + } + } + + if (best_result.is_valid()) { + best_algorithm_.set_algorithm(best_result.algorithm()); + } else { + LOG(ERROR) << "No convolution algorithm works with profiling. Fall back " + "to the default algorithm."; + best_algorithm_.set_algorithm(se::dnn::kDefaultAlgorithm); + } + + if (best_result_without_scratch.is_valid()) { + best_algorithm_.set_algorithm_no_scratch( + best_result_without_scratch.algorithm()); + } else { + LOG(ERROR) << "No convolution algorithm without scratch works with " + "profiling. 
Fall back " + "to the default algorithm."; + best_algorithm_.set_algorithm_no_scratch(se::dnn::kDefaultAlgorithm); + } + } + + { + VLOG(2) << "Using convolution algorithm (" << best_algorithm_.algorithm() + << ", " << best_algorithm_.algorithm_no_scratch() + << ") for ConvolutionThunk: " << this; + ConvolveScratchAllocator scratch_allocator( + buffer_allocations.device_ordinal(), + buffer_allocations.memory_allocator()); + return Convolve(input_descriptor, input_data, filter_descriptor, + filter_data, output_descriptor, output_data, + convolution_descriptor, best_algorithm_, stream, + &scratch_allocator, nullptr); + } +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h new file mode 100644 index 0000000000..cd9568f6a2 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h @@ -0,0 +1,149 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONVOLUTION_THUNK_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONVOLUTION_THUNK_H_ + +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/compiler/xla/service/gpu/thunk.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { + +// A one-time scratch allocator for forward and backward convolution. The +// scratch buffers allocated are released on destruction. +// +// Not thread-safe. +class ConvolveScratchAllocator : public perftools::gputools::ScratchAllocator { + public: + ConvolveScratchAllocator(int device_ordinal, + DeviceMemoryAllocator* memory_allocator); + + ~ConvolveScratchAllocator() override; + + int64 GetMemoryLimitInBytes(perftools::gputools::Stream* stream) override; + + int64 TotalAllocatedBytes() { return total_allocated_bytes_; } + + perftools::gputools::port::StatusOr> + AllocateBytes(perftools::gputools::Stream* stream, int64 byte_size) override; + + private: + const int device_ordinal_; + DeviceMemoryAllocator* memory_allocator_; + std::vector allocated_buffers_; + int64 total_allocated_bytes_ = 0; +}; + +// This class stores everything that StreamExecutor needs to launch a BNN +// convolution. It is generated by IrEmitter. +// +// This is thread-compatible. +class ConvolutionThunk : public Thunk { + public: + // ConvolutionThunk performs one of the following types of convolution. + enum class ConvolutionKind { + kBackwardFilter, // Backward convolution for filter. + kBackwardInput, // Backward convolution for input. 
+ kForward, // Forward convolution. + }; + + // Constructs a thunk for launching a DNN convolution. + // Semantics of null hlo_instruction argument are as in Thunk. + ConvolutionThunk(ConvolutionKind convolution_kind, + BufferAllocation::Index input_buffer, + BufferAllocation::Index filter_buffer, + BufferAllocation::Index output_buffer, + const Shape& input_shape, const Shape& filter_shape, + const Shape& output_shape, const Window& window, + const ConvolutionDimensionNumbers& dnums, + const HloInstruction* hlo); + + ConvolutionThunk(const ConvolutionThunk&) = delete; + ConvolutionThunk& operator=(const ConvolutionThunk&) = delete; + + // Does the convolution for the thunk on "stream". If the + // xla_gpu_autotune_convolution_algorithm is turned on, auto-tuning happens on + // the first run of this function. + tensorflow::Status ExecuteOnStream( + const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) override; + + private: + tensorflow::Status ConvolveWithTune( + const perftools::gputools::dnn::BatchDescriptor& input_descriptor, + perftools::gputools::DeviceMemory input_data, + const perftools::gputools::dnn::FilterDescriptor& filter_descriptor, + perftools::gputools::DeviceMemory filter_data, + const perftools::gputools::dnn::BatchDescriptor& output_descriptor, + perftools::gputools::DeviceMemory output_data, + const perftools::gputools::dnn::ConvolutionDescriptor& + convolution_descriptor, + const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream); + + tensorflow::Status Convolve( + const perftools::gputools::dnn::BatchDescriptor& input_descriptor, + perftools::gputools::DeviceMemory input_data, + const perftools::gputools::dnn::FilterDescriptor& filter_descriptor, + perftools::gputools::DeviceMemory filter_data, + const perftools::gputools::dnn::BatchDescriptor& output_descriptor, + perftools::gputools::DeviceMemory output_data, + const perftools::gputools::dnn::ConvolutionDescriptor& + convolution_descriptor, + const perftools::gputools::dnn::AlgorithmConfig& algorithm_config, + perftools::gputools::Stream* stream, + ConvolveScratchAllocator* scratch_allocator, + perftools::gputools::dnn::ProfileResult* profile_result); + + // Returns the convolve algorithms that can be used for this ConvolutionThunk. + std::vector GetAlgorithms( + perftools::gputools::StreamExecutor* stream_exec) const; + + // Fastest cuDNN convolution algorithm for this thunk learned from + // auto-tuning. If auto-tuning is disabled or failed, best_algorithm_ is set + // to the default value indicating cuDNN's convolution will choose + // the best algorithm from some heuristics based on its parameters. + perftools::gputools::dnn::AlgorithmConfig best_algorithm_; + + ConvolutionKind convolution_kind_; + + BufferAllocation::Index input_buffer_; + BufferAllocation::Index filter_buffer_; + BufferAllocation::Index output_buffer_; + + Shape input_shape_; + Shape filter_shape_; + Shape output_shape_; + + Window window_; + + ConvolutionDimensionNumbers dim_nums_; +}; + +string ConvolutionKindToString( + ConvolutionThunk::ConvolutionKind convolution_kind); + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONVOLUTION_THUNK_H_ diff --git a/tensorflow/compiler/xla/service/gpu/copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/copy_insertion.cc new file mode 100644 index 0000000000..926690c5c0 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/copy_insertion.cc @@ -0,0 +1,71 @@ +/* Copyright 2017 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/copy_insertion.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/service/copy_insertion.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/logical_buffer.h" +#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { +namespace gpu { + +StatusOr GpuCopyInsertion::Run(HloModule* module) { + TF_ASSIGN_OR_RETURN(bool changed, CopyInsertion::Run(module)); + + TF_ASSIGN_OR_RETURN(auto points_to_analysis, + TuplePointsToAnalysis::Run(module)); + + // Make sure all operands of a library call are in memory instead of constants + // in IR. The top-level (index {}) of the points-to set of each operand + // indicates the source(s) of the array buffer. If any of these are constant, + // then add a copy to materialize the array. + HloComputation* computation = module->entry_computation(); + for (HloInstruction* hlo : computation->MakeInstructionPostOrder()) { + if (ImplementedAsLibraryCall(*hlo)) { + for (int64 i = 0; i < hlo->operand_count(); ++i) { + HloInstruction* operand = hlo->mutable_operand(i); + const PointsToSet& points_to = + points_to_analysis->GetPointsToSet(operand); + const auto& element = points_to.element(/*index=*/{}); + if (std::any_of(element.begin(), element.end(), + [](const LogicalBuffer* buffer_source) { + return buffer_source->instruction()->opcode() == + HloOpcode::kConstant; + })) { + TF_ASSIGN_OR_RETURN(HloInstruction * copy, + CopyInsertion::FindOrInsertCopy(operand)); + hlo->ReplaceOperandWith(i, copy); + changed = true; + } + } + } + } + + return changed; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/copy_insertion.h b/tensorflow/compiler/xla/service/gpu/copy_insertion.h new file mode 100644 index 0000000000..11077dad2e --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/copy_insertion.h @@ -0,0 +1,36 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_COPY_INSERTION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_COPY_INSERTION_H_ + +#include "tensorflow/compiler/xla/service/copy_insertion.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" + +namespace xla { +namespace gpu { + +// Besides the modifications made by the generic xla::CopyInsertion, this +// GPU-specific copy insertion also materializes operands of library calls by +// inserting kCopy instructions. +class GpuCopyInsertion : public CopyInsertion { + public: + StatusOr Run(HloModule* module) override; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_COPY_INSERTION_H_ diff --git a/tensorflow/compiler/xla/service/gpu/copy_thunk.cc b/tensorflow/compiler/xla/service/gpu/copy_thunk.cc new file mode 100644 index 0000000000..76fb079bd4 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/copy_thunk.cc @@ -0,0 +1,41 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/copy_thunk.h" + +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { + +CopyThunk::CopyThunk(const void* source_address, + BufferAllocation::Index destination_buffer, + uint64 mem_size, const HloInstruction* hlo_instruction) + : Thunk(Kind::kCopy, hlo_instruction), + source_address_(source_address), + destination_buffer_(destination_buffer), + mem_size_(mem_size) {} + +tensorflow::Status CopyThunk::ExecuteOnStream( + const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) { + perftools::gputools::DeviceMemoryBase destination_data = + buffer_allocations.GetDeviceAddress(destination_buffer_); + stream->ThenMemcpy(&destination_data, source_address_, mem_size_); + return tensorflow::Status::OK(); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/copy_thunk.h b/tensorflow/compiler/xla/service/gpu/copy_thunk.h new file mode 100644 index 0000000000..803e699bfd --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/copy_thunk.h @@ -0,0 +1,56 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_COPY_THUNK_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_COPY_THUNK_H_ + +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/thunk.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace gpu { + +// A thunk that copies data. For now, it copies data only from host to device. +// But it can be extended to perform device-to-host or intra-device copying. +class CopyThunk : public Thunk { + public: + // Constructs a CopyThunk that copies host data from `source_address` to the + // device buffer `destination_buffer`. `mem_size` is the size of the data in + // bytes. + CopyThunk(const void* source_address, + BufferAllocation::Index destination_buffer, uint64 mem_size, + const HloInstruction* hlo_instruction); + + CopyThunk(const CopyThunk&) = delete; + CopyThunk& operator=(const CopyThunk&) = delete; + + tensorflow::Status ExecuteOnStream( + const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) override; + + private: + const void* source_address_; + BufferAllocation::Index destination_buffer_; + uint64 mem_size_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_COPY_THUNK_H_ diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc new file mode 100644 index 0000000000..e318ade5ee --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -0,0 +1,396 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h" + +#include +#include +#include + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" +// IWYU pragma: no_include "llvm/IR/Attributes.gen.inc" +// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" +#include "external/llvm/include/llvm/ADT/APInt.h" +#include "external/llvm/include/llvm/IR/BasicBlock.h" +#include "external/llvm/include/llvm/IR/Instructions.h" +#include "external/llvm/include/llvm/IR/Intrinsics.h" +#include "external/llvm/include/llvm/IR/Module.h" +#include "external/llvm/include/llvm/IR/Type.h" +#include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/window_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace xla { +namespace gpu { + +using llvm_ir::IrArray; +using llvm_ir::SetToFirstInsertPoint; + +GpuElementalIrEmitter::GpuElementalIrEmitter( + const HloModuleConfig& hlo_module_config, llvm::Module* module, + llvm::IRBuilder<>* ir_builder, NestedComputer compute_nested) + : ElementalIrEmitter(hlo_module_config, module, ir_builder), + compute_nested_(std::move(compute_nested)) {} + +StatusOr GpuElementalIrEmitter::EmitMathCall( + const string& callee_name, + tensorflow::gtl::ArraySlice operands, + tensorflow::gtl::ArraySlice input_types, + PrimitiveType output_type) const { + // Binary math functions tranform are of type [T] -> T. + for (PrimitiveType input_type : input_types) { + if (output_type != input_type) { + return Unimplemented("Input type ≠ output type: %s ≠ %s", + PrimitiveType_Name(input_type).c_str(), + PrimitiveType_Name(output_type).c_str()); + } + } + + // The libdevice math functions differentiate between "double" and "float" by + // appending an 'f' to the function's name. 
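+  // For example, the EmitMathCall("__nv_exp", ...) call made for
+  // HloOpcode::kExp in EmitFloatUnaryOp below resolves to the libdevice
+  // routine "__nv_exp" when output_type is F64 and to "__nv_expf" when
+  // output_type is F32.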
+ string function_name = callee_name; + switch (output_type) { + case F32: + function_name += 'f'; + break; + case F64: + break; + default: + return Unimplemented("Bad type for math call: %s", + PrimitiveType_Name(output_type).c_str()); + } + + return EmitDeviceFunctionCall( + function_name, operands, input_types, output_type, + {llvm::Attribute::ReadNone, llvm::Attribute::NoUnwind}); +} + +StatusOr GpuElementalIrEmitter::EmitFloatBinaryOp( + const HloInstruction* op, llvm::Value* lhs_value, + llvm::Value* rhs_value) const { + PrimitiveType lhs_input_type = op->operand(0)->shape().element_type(); + PrimitiveType rhs_input_type = op->operand(1)->shape().element_type(); + PrimitiveType output_type = op->shape().element_type(); + switch (op->opcode()) { + case HloOpcode::kRemainder: { + return EmitMathCall("__nv_fmod", {lhs_value, rhs_value}, + {lhs_input_type, rhs_input_type}, output_type); + } + case HloOpcode::kPower: { + return EmitMathCall("__nv_pow", {lhs_value, rhs_value}, + {lhs_input_type, rhs_input_type}, output_type); + } + default: + return ElementalIrEmitter::EmitFloatBinaryOp(op, lhs_value, rhs_value); + } +} + +StatusOr GpuElementalIrEmitter::EmitErfcInv( + PrimitiveType prim_type, llvm::Value* value) const { + return EmitMathCall("__nv_erfcinv", {value}, {prim_type}, prim_type); +} + +StatusOr GpuElementalIrEmitter::EmitFloatUnaryOp( + const HloInstruction* op, llvm::Value* operand_value) const { + PrimitiveType input_type = op->operand(0)->shape().element_type(); + PrimitiveType output_type = op->shape().element_type(); + switch (op->opcode()) { + case HloOpcode::kExp: + return EmitMathCall("__nv_exp", {operand_value}, {input_type}, + output_type); + case HloOpcode::kFloor: + return EmitMathCall("__nv_floor", {operand_value}, {input_type}, + output_type); + case HloOpcode::kCeil: + return EmitMathCall("__nv_ceil", {operand_value}, {input_type}, + output_type); + case HloOpcode::kLog: + return EmitMathCall("__nv_log", {operand_value}, {input_type}, + output_type); + case HloOpcode::kTanh: + return EmitMathCall("__nv_tanh", {operand_value}, {input_type}, + output_type); + default: + return ElementalIrEmitter::EmitFloatUnaryOp(op, operand_value); + } +} + +llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall( + const string& callee_name, + tensorflow::gtl::ArraySlice operands, + tensorflow::gtl::ArraySlice input_types, + PrimitiveType output_type, + tensorflow::gtl::ArraySlice attributes) const { + std::vector ir_input_types; + for (PrimitiveType input_type : input_types) { + ir_input_types.push_back( + llvm_ir::PrimitiveTypeToIrType(input_type, ir_builder_)); + } + llvm::FunctionType* callee_type = llvm::FunctionType::get( + llvm_ir::PrimitiveTypeToIrType(output_type, + ir_builder_), // The return type. + ir_input_types, // The parameter types. + false); // No variadic arguments. + + // Declares the callee if it is not declared already. 
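+  // getOrInsertFunction returns the existing declaration when the module
+  // already contains a function with this name and type, and otherwise adds
+  // a new declaration with `callee_type`; the cast recovers the
+  // llvm::Function from the returned constant.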
+ llvm::Function* callee = llvm::cast( + ir_builder_->GetInsertBlock()->getModule()->getOrInsertFunction( + llvm_ir::AsStringRef(callee_name), callee_type)); + + for (auto attribute : attributes) { + callee->addFnAttr(attribute); + } + + return ir_builder_->CreateCall(callee, llvm_ir::AsArrayRef(operands)); +} + +llvm::Value* GpuElementalIrEmitter::EmitThreadId() const { + llvm::Value* block_id = ir_builder_->CreateIntCast( + llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, + {}, {}, ir_builder_), + ir_builder_->getIntNTy(128), /*isSigned=*/true, "block.id"); + llvm::Value* thread_id_in_block = ir_builder_->CreateIntCast( + llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, + {}, {}, ir_builder_), + ir_builder_->getIntNTy(128), /*isSigned=*/true, "thread.id"); + llvm::Value* threads_per_block = ir_builder_->CreateIntCast( + llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_x, + {}, {}, ir_builder_), + ir_builder_->getIntNTy(128), /*isSigned=*/true, "threads_per_block"); + return ir_builder_->CreateNSWAdd( + ir_builder_->CreateNSWMul(block_id, threads_per_block), + thread_id_in_block); +} + +llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( + const HloInstruction* hlo, + const HloToElementGeneratorMap& operand_to_generator) const { + switch (hlo->opcode()) { + case HloOpcode::kPad: + return [=, &operand_to_generator]( + const IrArray::Index& padded_index) -> StatusOr { + auto index = padded_index; + llvm::Value* in_bounds = + llvm::ConstantInt::get(ir_builder_->getInt1Ty(), 1); + for (int i = 0; i < index.size(); ++i) { + auto index_typed_const = [=](int64 n) { + return llvm::ConstantInt::get(index[i]->getType(), n); + }; + const auto& pad_dim = hlo->padding_config().dimensions(i); + index[i] = ir_builder_->CreateSub( + index[i], index_typed_const(pad_dim.edge_padding_low())); + in_bounds = ir_builder_->CreateAnd( + in_bounds, + ir_builder_->CreateICmpSGE(index[i], index_typed_const(0)), + "in_bounds"); + in_bounds = ir_builder_->CreateAnd( + in_bounds, + ir_builder_->CreateICmpEQ( + index_typed_const(0), + ir_builder_->CreateURem( + index[i], + index_typed_const(pad_dim.interior_padding() + 1))), + "in_bounds"); + index[i] = ir_builder_->CreateSDiv( + index[i], index_typed_const(pad_dim.interior_padding() + 1)); + in_bounds = ir_builder_->CreateAnd( + in_bounds, + ir_builder_->CreateICmpSLT( + index[i], + index_typed_const(hlo->operand(0)->shape().dimensions(i))), + "in_bounds"); + } + + // if (in_bounds) { + // ret_value = operand0[index]; // source + // } else { + // ret_value = *operand1; // padding + // } + llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry( + llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), + ir_builder_), + "pad_result_addr", ir_builder_); + llvm_ir::LlvmIfData if_data = + llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_); + SetToFirstInsertPoint(if_data.true_block, ir_builder_); + TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, + operand_to_generator.at(hlo->operand(0))(index)); + ir_builder_->CreateStore(operand_value, ret_value_addr); + + SetToFirstInsertPoint(if_data.false_block, ir_builder_); + TF_ASSIGN_OR_RETURN(llvm::Value * padding_value, + operand_to_generator.at(hlo->operand(1))({})); + ir_builder_->CreateStore(padding_value, ret_value_addr); + + SetToFirstInsertPoint(if_data.after_block, ir_builder_); + // Don't create phi(operand_value, padding_value) here, because invoking + // operand_to_generator may create new basic blocks, 
making the parent + // of operand_value or padding_value no longer a predecessor of + // if_data.after_block. + return ir_builder_->CreateLoad(ret_value_addr); + }; + case HloOpcode::kMap: + return [=, &operand_to_generator]( + const IrArray::Index& index) -> StatusOr { + TF_RET_CHECK(!hlo->operands().empty()) + << "Zero operand map not implemented in GPU backend."; + TF_RET_CHECK(hlo->to_apply()->num_parameters() > 0); + std::vector operand_elements; + for (HloInstruction* operand : hlo->operands()) { + TF_ASSIGN_OR_RETURN(llvm::Value * value, + operand_to_generator.at(operand)(index)); + operand_elements.push_back(value); + } + return compute_nested_(*hlo->to_apply(), operand_elements); + }; + case HloOpcode::kReduceWindow: + // Pseudocode: + // for each index I in output + // value = init_value + // for each index W in window + // for each dimension i from 0 to rank - 1 + // (input index I)[i] = O[i] * stride[i] + W[i] - pad_low[i] + // if I in bounds of input + // value = function(value, input[I]) + // output[O] = value + return [=, &operand_to_generator]( + const IrArray::Index& index) -> StatusOr { + const HloInstruction* operand = hlo->operand(0); + const Window& window = hlo->window(); + + // TODO(b/31410564): Implement dilation for reduce-window. + if (window_util::HasDilation(window)) { + return Unimplemented( + "Dilation for reduce-window not implemented on GPU. " + "See b/31410564."); + } + + PrimitiveType operand_element_type = operand->shape().element_type(); + llvm::Value* accum_ptr = llvm_ir::EmitAllocaAtFunctionEntry( + llvm_ir::PrimitiveTypeToIrType(operand_element_type, ir_builder_), + "reduce_window_accum_ptr", ir_builder_); + { + TF_ASSIGN_OR_RETURN(llvm::Value * init_value, + operand_to_generator.at(hlo->operand(1))({})); + ir_builder_->CreateStore(init_value, accum_ptr); + } + + llvm_ir::ForLoopNest loops(ir_builder_); + std::vector window_size; + for (const auto& dim : window.dimensions()) { + window_size.push_back(dim.size()); + } + const IrArray::Index window_index = loops.AddLoopsForShape( + ShapeUtil::MakeShape(operand_element_type, window_size), "window"); + CHECK_EQ(window_index.size(), index.size()); + + SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), ir_builder_); + + IrArray::Index input_index(index.size()); + llvm::Value* in_bounds = ir_builder_->getInt1(1); + for (size_t i = 0; i < index.size(); ++i) { + llvm::Value* stridden_index = ir_builder_->CreateNSWMul( + index[i], ir_builder_->getInt64(window.dimensions(i).stride())); + input_index[i] = ir_builder_->CreateNSWSub( + ir_builder_->CreateNSWAdd(stridden_index, window_index[i]), + ir_builder_->getInt64(window.dimensions(i).padding_low())); + + // We must check whether 0 ≤ input_index[i] < bound, as otherwise + // we are in the pad and so can skip the computation. This + // comparison is equivalent to the unsigned comparison + // input_index[i] < bound, as a negative value wraps to a large + // positive value. + in_bounds = ir_builder_->CreateAnd( + in_bounds, + ir_builder_->CreateICmpULT( + input_index[i], + ir_builder_->getInt64(operand->shape().dimensions(i)))); + } + + llvm_ir::LlvmIfData if_data = + llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_); + SetToFirstInsertPoint(if_data.true_block, ir_builder_); + + // We are not in pad, so do the computation. 
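+        // Load the running value, apply the reduction computation to
+        // (accumulator, current input element), and store the result back so
+        // the next iteration of the window loop sees the updated accumulator.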
+ TF_ASSIGN_OR_RETURN(llvm::Value * input_value, + operand_to_generator.at(operand)(input_index)); + TF_ASSIGN_OR_RETURN( + llvm::Value * accum_value, + compute_nested_(*hlo->to_apply(), + {ir_builder_->CreateLoad(accum_ptr), input_value})); + ir_builder_->CreateStore(accum_value, accum_ptr); + + SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), ir_builder_); + return ir_builder_->CreateLoad(accum_ptr); + }; + case HloOpcode::kReduce: + return [=, &operand_to_generator]( + const IrArray::Index& output_index) -> StatusOr { + const HloInstruction* operand = hlo->operand(0); + llvm::Value* accum_ptr = + ir_builder()->CreateAlloca(llvm_ir::PrimitiveTypeToIrType( + hlo->shape().element_type(), ir_builder())); + TF_ASSIGN_OR_RETURN(llvm::Value * init_value, + operand_to_generator.at(hlo->operand(1))({})); + ir_builder()->CreateStore(init_value, accum_ptr); + + llvm_ir::ForLoopNest loops(ir_builder_); + IrArray::Index input_index = loops.AddLoopsForShapeOnDimensions( + operand->shape(), hlo->dimensions(), "reduction_dim"); + if (!ShapeUtil::IsScalar(hlo->shape())) { + // Here only input_index[hlo->dimensions()] are non-null, so we must + // set the rest. + size_t j = 0; + for (size_t i = 0; i < input_index.size(); ++i) { + if (input_index[i] == nullptr) { + input_index[i] = output_index[j++]; + } + } + CHECK_EQ(output_index.size(), j); + } + + SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), ir_builder()); + TF_ASSIGN_OR_RETURN( + llvm::Value * input_value, + operand_to_generator.at(hlo->operand(0))(input_index)); + TF_ASSIGN_OR_RETURN( + llvm::Value * accum_value, + compute_nested_( + *hlo->to_apply(), + {ir_builder()->CreateLoad(accum_ptr), input_value})); + ir_builder()->CreateStore(accum_value, accum_ptr); + SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), ir_builder()); + return ir_builder()->CreateLoad(accum_ptr); + }; + default: + return ElementalIrEmitter::MakeElementGenerator(hlo, + operand_to_generator); + } +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h new file mode 100644 index 0000000000..8e5bdc59cf --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h @@ -0,0 +1,91 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_ELEMENTAL_IR_EMITTER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_ELEMENTAL_IR_EMITTER_H_ + +#include +#include +#include + +#include "external/llvm/include/llvm/IR/IRBuilder.h" +#include "external/llvm/include/llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace xla { +namespace gpu { + +class GpuElementalIrEmitter : public ElementalIrEmitter { + public: + // A NestedComputer computes an element of the output of the given computation + // given an ArraySlice of its input elements. + using NestedComputer = std::function( + const HloComputation&, tensorflow::gtl::ArraySlice)>; + + GpuElementalIrEmitter(const HloModuleConfig& hlo_module_config, + llvm::Module* module, llvm::IRBuilder<>* ir_builder, + NestedComputer compute_nested); + + llvm_ir::ElementGenerator MakeElementGenerator( + const HloInstruction* hlo, + const HloToElementGeneratorMap& operand_to_generator) const override; + + protected: + StatusOr EmitFloatUnaryOp( + const HloInstruction* op, llvm::Value* operand_value) const override; + + StatusOr EmitFloatBinaryOp( + const HloInstruction* op, llvm::Value* lhs_value, + llvm::Value* rhs_value) const override; + + StatusOr EmitErfcInv(PrimitiveType prim_type, + llvm::Value* value) const override; + + llvm::Value* EmitThreadId() const override; + + private: + // Emit IR to call a device function named "callee_name" on the given + // operand. Returns the IR value that represents the return value. + llvm::Value* EmitDeviceFunctionCall( + const string& callee_name, + tensorflow::gtl::ArraySlice operands, + tensorflow::gtl::ArraySlice input_type, + PrimitiveType output_type, + tensorflow::gtl::ArraySlice attributes) const; + + // Emit IR to call a device function of type [T] -> T. It adjusts the + // callee_name to account for float/double types. + // Returns the IR value that represents the return value. + StatusOr EmitMathCall( + const string& callee_name, + tensorflow::gtl::ArraySlice operands, + tensorflow::gtl::ArraySlice input_types, + PrimitiveType output_type) const; + + NestedComputer compute_nested_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_ELEMENTAL_IR_EMITTER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.cc b/tensorflow/compiler/xla/service/gpu/for_thunk.cc new file mode 100644 index 0000000000..283d21ca22 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/for_thunk.cc @@ -0,0 +1,50 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/for_thunk.h" + +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace xla { +namespace gpu { + +ForThunk::ForThunk(const int64 loop_limit, + std::unique_ptr body_thunk_sequence, + const HloInstruction* hlo) + : Thunk(Kind::kWhile, hlo), + loop_limit_(loop_limit), + body_thunk_sequence_( + MakeUnique(std::move(*body_thunk_sequence), hlo)) {} + +tensorflow::Status ForThunk::Initialize(const GpuExecutable& executable) { + TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(executable)); + return tensorflow::Status::OK(); +} + +tensorflow::Status ForThunk::ExecuteOnStream( + const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) { + for (int64 i = 0; i < loop_limit_; ++i) { + // Invoke loop body thunk sequence. + TF_RETURN_IF_ERROR( + body_thunk_sequence_->ExecuteOnStream(buffer_allocations, stream)); + } + return tensorflow::Status::OK(); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.h b/tensorflow/compiler/xla/service/gpu/for_thunk.h new file mode 100644 index 0000000000..525a2af941 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/for_thunk.h @@ -0,0 +1,52 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FOR_THUNK_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FOR_THUNK_H_ + +#include + +#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/thunk.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { + +// ForThunk executes 'loop_limit' invocations of 'body_thunk_sequence'. 
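+// The trip count is fixed when the thunk is built, so ExecuteOnStream simply
+// replays the enclosed thunk sequence 'loop_limit' times on the given stream.
+// Illustrative sketch only (names are hypothetical, not part of this change):
+//   ForThunk for_thunk(/*loop_limit=*/16, std::move(body_thunks), hlo);
+//   TF_RETURN_IF_ERROR(
+//       for_thunk.ExecuteOnStream(buffer_allocations, stream));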
+class ForThunk : public Thunk { + public: + ForThunk(const int64 loop_limit, + std::unique_ptr body_thunk_sequence, + const HloInstruction* hlo); + ForThunk(const ForThunk&) = delete; + ForThunk& operator=(const ForThunk&) = delete; + + tensorflow::Status Initialize(const GpuExecutable& executable) override; + tensorflow::Status ExecuteOnStream( + const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) override; + + private: + const int64 loop_limit_; + std::unique_ptr body_thunk_sequence_; +}; + +} // namespace gpu +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FOR_THUNK_H_ diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc new file mode 100644 index 0000000000..98a8a4a2b1 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc @@ -0,0 +1,189 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h" + +#include + +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/types.h" + +namespace se = ::perftools::gputools; + +namespace xla { +namespace gpu { + +using Index = BufferAllocation::Index; + +namespace { + +// This struct contains the metadata of a matrix, e.g., its base address and +// dimensions. +struct MatrixDescriptor { + MatrixDescriptor(se::DeviceMemoryBase matrix_data, bool needs_transpose, + int64 matrix_num_rows, int64 matrix_num_cols) + : data(matrix_data), + transpose(needs_transpose), + num_rows(matrix_num_rows), + num_cols(matrix_num_cols) {} + + se::DeviceMemoryBase data; + bool transpose; // Whether this matrix needs to be transposed. + int64 num_rows; + int64 num_cols; +}; + +// Performs a gemm call on lhs_matrix and rhs_matrix and stores the result to +// output_matrix. +template +tensorflow::Status DoGemm(MatrixDescriptor lhs_matrix, + MatrixDescriptor rhs_matrix, + MatrixDescriptor output_matrix, se::Stream* stream) { + DCHECK(!output_matrix.transpose); + + se::DeviceMemory lhs_data(lhs_matrix.data); + se::DeviceMemory rhs_data(rhs_matrix.data); + se::DeviceMemory output_data(output_matrix.data); + + bool launch_ok = + stream + ->ThenBlasGemm( + lhs_matrix.transpose ? se::blas::Transpose::kTranspose + : se::blas::Transpose::kNoTranspose, + rhs_matrix.transpose ? se::blas::Transpose::kTranspose + : se::blas::Transpose::kNoTranspose, + output_matrix.num_rows, output_matrix.num_cols, + lhs_matrix.transpose + ? lhs_matrix.num_rows + : lhs_matrix.num_cols, // Size of the reduce dimension. + /*alpha=*/1.0, + lhs_data, + lhs_matrix.num_rows, // The leading dimension of LHS. + rhs_data, + rhs_matrix.num_rows, // The leading dimension of RHS. 
+ /*beta=*/0.0, &output_data, + output_matrix + .num_rows) // The leading dimension of the output matrix. + .ok(); + if (!launch_ok) { + return InternalError("Unable to launch cuBLAS gemm on stream %p", stream); + } + return tensorflow::Status::OK(); +} + +// Return, if the given type is a valid Gemm elemental type, the executor for +// that type, else null. +// TODO(b/27202055): consider more element types. +std::function +FindGemmExecutor(PrimitiveType type) { + switch (type) { + case F32: + return &DoGemm; + case F64: + return &DoGemm; + default: + return nullptr; + } +} + +} // namespace + +GemmThunk::GemmThunk(Index lhs_buffer, Index rhs_buffer, Index output_buffer, + const Shape& lhs_shape, const Shape& rhs_shape, + const Shape& output_shape, bool transpose_lhs, + bool transpose_rhs, const HloInstruction* hlo_instruction) + : Thunk(Kind::kGemm, hlo_instruction), + lhs_buffer_(lhs_buffer), + rhs_buffer_(rhs_buffer), + output_buffer_(output_buffer), + lhs_shape_(lhs_shape), + rhs_shape_(rhs_shape), + output_shape_(output_shape), + transpose_lhs_(transpose_lhs), + transpose_rhs_(transpose_rhs) {} + +tensorflow::Status GemmThunk::ExecuteOnStream( + const BufferAllocations& buffer_allocations, se::Stream* stream) { + VLOG(2) << "Executing a GemmThunk"; + auto executor = FindGemmExecutor(output_shape_.element_type()); + DCHECK(executor != nullptr); + + se::DeviceMemoryBase lhs_data = + buffer_allocations.GetDeviceAddress(lhs_buffer_); + se::DeviceMemoryBase rhs_data = + buffer_allocations.GetDeviceAddress(rhs_buffer_); + se::DeviceMemoryBase output_data = + buffer_allocations.GetDeviceAddress(output_buffer_); + + // BLAS gemm reduces rows of LHS and columns of RHS. The Dot operator between + // matrices reduces dimension 1 of LHS and dimension 0 of RHS regardless of + // their layout. Therefore, we should treat dimension 0 as row and dimension 1 + // as column when mapping a matrix Dot to BLAS gemm. + int64 output_num_rows = output_shape_.dimensions(0); + int64 output_num_cols = output_shape_.dimensions(1); + + // BLAS gemm expects the inputs and the output are in column-major order. + // Therefore, we need to convert dot between row-major matrices to that + // between column-major matrices. The key insight for the conversion is that, + // in linear storage, matrix M in column-major order is identical to the + // tranpose of M in row-major order. In other words, + // + // column-major(M) = row-major(M^T). + // + // Leveraging this insight, we can perform dot between row-major matrices as + // follows. + // + // row-major(C) + // = row-major(A x B) = column-major((A x B)^T) = column-major(B^T x A^T) + // = gemm(column-major(B^T), column-major(A^T)) + // = gemm(row-major(B), row-major(A)) + // + // Although we do not modify the content of A and B in linear memory, we + // should use the dimensions of B^T and A^T when calling gemm. For example, + // the leading dimension of the LHS matrix of gemm is the number of rows in + // B^T and thus the number of columns in B. 
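+  // Worked example with illustrative shapes: for a row-major product
+  // C[2,3] = A[2,4] x B[4,3] with no transposes requested, the code below
+  // passes B as the gemm LHS and A as the gemm RHS and issues (alpha and
+  // beta elided)
+  //   ThenBlasGemm(kNoTranspose, kNoTranspose, /*m=*/3, /*n=*/2, /*k=*/4,
+  //                B /*leading dim 3*/, A /*leading dim 4*/,
+  //                &C /*leading dim 3*/),
+  // which computes column-major(C^T) = B^T x A^T, occupying exactly the same
+  // bytes as row-major(C).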
+ + auto make_descriptor = [this](se::DeviceMemoryBase data, const Shape& shape, + bool transpose) -> MatrixDescriptor { + bool is_row_major = shape.layout().minor_to_major(0) != 0; + bool layout_mismatch = shape.layout().minor_to_major(0) != + output_shape_.layout().minor_to_major(0); + return MatrixDescriptor(data, transpose ^ layout_mismatch, + shape.dimensions(is_row_major), + shape.dimensions(!is_row_major)); + }; + + const MatrixDescriptor lhs_descriptor = + make_descriptor(lhs_data, lhs_shape_, transpose_lhs_); + const MatrixDescriptor rhs_descriptor = + make_descriptor(rhs_data, rhs_shape_, transpose_rhs_); + if (output_shape_.layout().minor_to_major(0) == 0) { + return executor( + lhs_descriptor, rhs_descriptor, + MatrixDescriptor(output_data, false, output_num_rows, output_num_cols), + stream); + } else { + return executor( + rhs_descriptor, lhs_descriptor, + MatrixDescriptor(output_data, false, output_num_cols, output_num_rows), + stream); + } +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h new file mode 100644 index 0000000000..7c8574d275 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h @@ -0,0 +1,71 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GEMM_THUNK_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GEMM_THUNK_H_ + +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/compiler/xla/service/gpu/thunk.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { + +// This class stores everything that StreamExecutor needs to launch a BLAS gemm. +// It is generated by IrEmitter. +// +// This is thread-compatible. +class GemmThunk : public Thunk { + public: + // Constructs a thunk that computes "output = lhs rhs" using BLAS gemm. + // transpose_lhs and transpose_rhs indicate whether gemm should transpose the + // lhs and rhs operand. hlo_instruction is as in Thunk. + GemmThunk(BufferAllocation::Index lhs_buffer, + BufferAllocation::Index rhs_buffer, + BufferAllocation::Index output_buffer, const Shape& lhs_shape, + const Shape& rhs_shape, const Shape& output_shape, + bool transpose_lhs, bool transpose_rhs, + const HloInstruction* hlo_instruction); + + GemmThunk(const GemmThunk&) = delete; + GemmThunk& operator=(const GemmThunk&) = delete; + + // Does the gemm operation for the thunk on "stream", which must be non-null. 
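+  // Only F32 and F64 element types are supported; FindGemmExecutor in the
+  // implementation returns null for other types, which trips a DCHECK.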
+ tensorflow::Status ExecuteOnStream( + const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) override; + + private: + BufferAllocation::Index lhs_buffer_; + BufferAllocation::Index rhs_buffer_; + BufferAllocation::Index output_buffer_; + + Shape lhs_shape_; + Shape rhs_shape_; + Shape output_shape_; + + bool transpose_lhs_; + bool transpose_rhs_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GEMM_THUNK_H_ diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc new file mode 100644 index 0000000000..a13279c6ff --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -0,0 +1,335 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h" + +#include +#include +#include + +#include "external/llvm/include/llvm/IR/DiagnosticInfo.h" +#include "external/llvm/include/llvm/IR/DiagnosticPrinter.h" +#include "external/llvm/include/llvm/IR/LLVMContext.h" +#include "external/llvm/include/llvm/IR/Module.h" +#include "tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/algebraic_simplifier.h" +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/buffer_liveness.h" +#include "tensorflow/compiler/xla/service/gpu/convolution_folding.h" +#include "tensorflow/compiler/xla/service/gpu/copy_insertion.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h" +#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" +#include "tensorflow/compiler/xla/service/gpu/layout_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" +#include "tensorflow/compiler/xla/service/gpu/pad_insertion.h" +#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/temp_buffer_offsets.h" +#include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_cse.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_pass.h" +#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" +#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h" +#include 
"tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/service/reshape_mover.h" +#include "tensorflow/compiler/xla/service/transpose_folding.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/cuda_libdevice_path.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/subprocess.h" + +namespace se = ::perftools::gputools; + +namespace xla { +namespace gpu { + +namespace { + +// The triple that represents our target. +const char* kTargetTriple = "nvptx64-nvidia-cuda"; + +// The data layout of the emitted module. Copied from computeDataLayout in +// NVPTXTargetMachine.cpp. +const char* kDataLayout = "e-i64:64-v16:16-v32:32-n16:32:64"; + +// Returns the directory containing nvvm libdevice files. This function is +// called in GpuCompiler's constructor, so can't return an error. But +// GpuCompiler::Compile will return an error when the wanted libdevice file +// doesn't exist in the folder this function returns. +string GetLibdeviceDir() { + std::vector potential_libdevice_dirs; + // Flag xla_cuda_data_dir specified by the user. + legacy_flags::GpuCompilerFlags* flags = legacy_flags::GetGpuCompilerFlags(); + const string datadir = flags->xla_cuda_data_dir; + if (!datadir.empty()) { + potential_libdevice_dirs.push_back(datadir); + } + potential_libdevice_dirs.push_back(tensorflow::LibdeviceRoot()); + + // Tries all potential libdevice directories in the order they are inserted. + // Returns the first directory that exists in the file system. + for (const string& potential_libdevice_dir : potential_libdevice_dirs) { + if (tensorflow::Env::Default()->IsDirectory(potential_libdevice_dir).ok()) { + VLOG(2) << "Found libdevice dir " << potential_libdevice_dir; + return potential_libdevice_dir; + } + VLOG(2) << "Unable to find potential libdevice dir " + << potential_libdevice_dir; + } + + // Last resort: maybe in the current folder. + return "."; +} + +// Runs optimization passes on the given HLO module. +tensorflow::Status OptimizeHloModule(HloModule* hlo_module, + const Compiler::HloDumper& dump_hlo, + const se::DeviceDescription& device_desc) { + { + HloPassPipeline pipeline("optimization", dump_hlo); + { + auto& pass = pipeline.AddPass>( + "simplification", dump_hlo); + pass.AddPass( + /*is_layout_sensitive=*/false, + [](const Shape&, const Shape&) { return false; }); + pass.AddPass(); + } + pipeline.AddPass(); + pipeline.AddPass(ImplementedAsGemm); + pipeline.AddPass(); + pipeline.AddPass(/*is_layout_sensitive=*/false); + pipeline.AddPass(); + TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); + } + { + HloPassFix fusion("fusion", dump_hlo); + fusion.AddPass(/*may_duplicate=*/false); + fusion.AddPass(/*may_duplicate=*/true); + return fusion.Run(hlo_module).status(); + } +} + +// Modifies the given HLO module so that it will be accepted by IrEmitter. +// Unlike optimization passes, the passes are necessary for correctness. +tensorflow::Status PrepareHloModuleForIrEmitting( + const Compiler::HloDumper& dump_hlo, HloModule* hlo_module, + HloModuleConfig* module_config) { + // In some cases, we have to place the result of an instruction in a temporary + // buffer. 
For instance, the buffer that holds an external parameter is + // assumed immutable at this point, and should not be reused for output + // (b/27180329). Therefore, in that case, we set the output to be a copy of + // the parameter. + HloPassPipeline pipeline("GPU-ir-emit-prepare", dump_hlo); + pipeline.AddPass(); + pipeline.AddPass( + module_config->mutable_entry_computation_layout()); + // The LayoutAssignment pass may leave behind kCopy instructions which are + // duplicate or NOPs, so remove them with algebraic simplification and CSE. + pipeline.AddPass>( + /*is_layout_sensitive=*/true, + [](const Shape&, const Shape&) { return true; }); + pipeline.AddPass(/*is_layout_sensitive=*/true); + // Copy insertion should be performed immediately before IR emission to avoid + // inserting unnecessary copies (later pass adds an instruction which + // materializes the value) or missing a necessary copy (later pass removes an + // instruction which materializes a value). + pipeline.AddPass(); + pipeline.AddPass(); + return pipeline.Run(hlo_module).status(); +} + +// Invokes the ptxas tool on the given PTX string, and dumps its output. +void DumpPtxasInfo(const string& ptx) { + legacy_flags::GpuCompilerFlags* flags = legacy_flags::GetGpuCompilerFlags(); + const string ptxas_path = flags->xla_ptxas_path; + // Do not log PTX stats if ptxas is not found at the given path. + if (!tensorflow::Env::Default()->FileExists(ptxas_path).ok()) { + LOG(WARNING) + << "Failed to dump PTX stats because ptxas is not found at path \"" + << ptxas_path << "\"."; + return; + } + + // Write `ptx` into a temporary file. + char tempdir_template[] = "/tmp/ptxXXXXXX"; + char* tempdir_name = mkdtemp(tempdir_template); + CHECK_NOTNULL(tempdir_name); + string ptx_path = tensorflow::io::JoinPath(tempdir_name, "ptx"); + TF_CHECK_OK( + tensorflow::WriteStringToFile(tensorflow::Env::Default(), ptx_path, ptx)); + LOG(INFO) << "ptx file written to: " << ptx_path; + + // Invoke ptxas and collect its output. + tensorflow::SubProcess ptxas_info_dumper; + ptxas_info_dumper.SetProgram(ptxas_path, {ptxas_path, ptx_path, "-o", + "/dev/null", "-v", "-arch=sm_35"}); + ptxas_info_dumper.SetChannelAction(tensorflow::CHAN_STDERR, + tensorflow::ACTION_PIPE); + CHECK(ptxas_info_dumper.Start()); + string stderr_output; + int exit_status = ptxas_info_dumper.Communicate( + /*stdin_input=*/nullptr, /*stdout_output=*/nullptr, &stderr_output); + XLA_LOG_LINES(tensorflow::INFO, stderr_output); + if (exit_status != 0) { + LOG(FATAL) << "Invalid PTX. 
See the error message above for reasons."; + } +} + +} // namespace + +GpuCompiler::GpuCompiler() : libdevice_dir_(GetLibdeviceDir()) {} + +StatusOr> GpuCompiler::Compile( + std::unique_ptr hlo_module, + std::unique_ptr module_config, HloDumper dump_hlo, + se::StreamExecutor* stream_exec) { + TF_RET_CHECK(stream_exec != nullptr); + + TF_RETURN_IF_ERROR(OptimizeHloModule(hlo_module.get(), dump_hlo, + stream_exec->GetDeviceDescription())); + TF_RETURN_IF_ERROR(PrepareHloModuleForIrEmitting(dump_hlo, hlo_module.get(), + module_config.get())); + + llvm::LLVMContext llvm_context; + std::string buffer; + llvm::raw_string_ostream error(buffer); + llvm::DiagnosticPrinterRawOStream printer(error); + auto DiagnosticHandler = [](const llvm::DiagnosticInfo& diag_info, + void* Context) { + auto printer = static_cast(Context); + diag_info.print(*printer); + }; + llvm_context.setDiagnosticHandler(DiagnosticHandler, &printer); + + llvm::Module llvm_module(hlo_module->name().c_str(), llvm_context); + // Set the target triple and the data layout. + llvm_module.setTargetTriple(kTargetTriple); + llvm_module.setDataLayout(kDataLayout); + const llvm::DataLayout& data_layout = llvm_module.getDataLayout(); + int64 pointer_size = data_layout.getPointerSize(); + + // Determine the HLO schedule, which is an ordering of HLO instructions. This + // is used by buffer assignment to enable buffer reuse, and the same ordering + // must also be used to determine the thunk launch schedule. + std::unique_ptr stream_assignment = + AssignStreams(*hlo_module); + TF_ASSIGN_OR_RETURN(std::unique_ptr hlo_schedule, + HloSchedule::Build(*hlo_module, *stream_assignment)); + + // Run buffer analysis on the HLO graph. This analysis figures out which + // temporary buffers are required to run the computation. + TF_ASSIGN_OR_RETURN( + std::unique_ptr buffer_assignment, + BufferAssigner::Run(hlo_module.get(), hlo_schedule->ConsumeHloOrdering(), + pointer_size)); + auto temp_buffer_offsets = MakeUnique(*buffer_assignment); + + IrEmitterContext ir_emitter_context( + hlo_module.get(), buffer_assignment.get(), temp_buffer_offsets.get(), + &stream_exec->GetDeviceDescription(), &llvm_module); + + HloComputation* entry_computation = hlo_module->entry_computation(); + IrEmitterUnnested ir_emitter(*module_config, entry_computation, + module_config->has_hybrid_result(), + &ir_emitter_context); + TF_RETURN_IF_ERROR( + entry_computation->root_instruction()->Accept(&ir_emitter)); + + string ir_module_string_before_opt; + legacy_flags::GpuCompilerFlags* flags = legacy_flags::GetGpuCompilerFlags(); + if (VLOG_IS_ON(2) || flags->xla_gpu_embed_ir) { + ir_module_string_before_opt = llvm_ir::DumpModuleToString(llvm_module); + VLOG(2) << "LLVM module before optimizations:"; + XLA_VLOG_LINES(2, ir_module_string_before_opt); + } + + // Reserve space for the PTX to be generated for this module. 
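+  // The string is owned by generated_ptxes_ (guarded by mutex_) so that the
+  // PTX outlives this call; the GpuExecutable constructed below receives it
+  // as a StringPiece rather than taking ownership.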
+ string* ptx; + { + tensorflow::mutex_lock lock(mutex_); + generated_ptxes_.emplace_back(MakeUnique()); + ptx = generated_ptxes_.back().get(); + } + TF_ASSIGN_OR_RETURN(*ptx, CompileToPtx(&llvm_module, libdevice_dir_)); + + VLOG(2) << "LLVM module after optimizations:"; + XLA_VLOG_LINES(2, llvm_ir::DumpModuleToString(llvm_module)); + VLOG(2) << "PTX:"; + XLA_VLOG_LINES(2, *ptx); + if (VLOG_IS_ON(2)) { + DumpPtxasInfo(*ptx); + } + + auto thunk_schedule = MakeUnique( + ir_emitter.ConsumeThunkSequence(), std::move(stream_assignment), + hlo_schedule->ThunkLaunchOrder()); + VLOG(2) << "Printing the thunk schedule..."; + XLA_VLOG_LINES(2, thunk_schedule->ToString()); + + auto* gpu_executable = + new GpuExecutable(*ptx, std::move(thunk_schedule), std::move(hlo_module), + std::move(module_config), std::move(buffer_assignment), + std::move(temp_buffer_offsets)); + if (flags->xla_gpu_embed_ir) { + DCHECK_NE("", ir_module_string_before_opt); + gpu_executable->set_ir_module_string(ir_module_string_before_opt); + } + return std::unique_ptr(gpu_executable); +} + +StatusOr>> GpuCompiler::Compile( + std::vector> hlo_modules, + std::vector> module_configs, + HloDumper dump_hlos, std::vector stream_execs) { + return Unimplemented( + "Compilation of multiple HLO modules is not yet supported on GPU."); +} + +StatusOr> GpuCompiler::CompileAheadOfTime( + std::unique_ptr module, + std::unique_ptr module_config, HloDumper dump_hlo, + const AotCompilationOptions& options) { + return Unimplemented("not yet implemented: GpuCompiler::CompileAheadOfTime"); +} + +se::Platform::Id GpuCompiler::PlatformId() const { + return se::cuda::kCudaPlatformId; +} + +} // namespace gpu +} // namespace xla + +static bool InitModule() { + xla::Compiler::RegisterCompilerFactory(se::cuda::kCudaPlatformId, []() { + return xla::MakeUnique(); + }); + return true; +} +static bool module_initialized = InitModule(); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h new file mode 100644 index 0000000000..fefa403104 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h @@ -0,0 +1,78 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_COMPILER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_COMPILER_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/service/compiler.h" +#include "tensorflow/compiler/xla/service/executable.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/thread_annotations.h" + +namespace xla { +namespace gpu { + +// The GPU compiler generates efficient GPU executables. +class GpuCompiler : public Compiler { + public: + GpuCompiler(); + ~GpuCompiler() override {} + + StatusOr> Compile( + std::unique_ptr hlo_module, + std::unique_ptr module_config, HloDumper dump_hlo, + perftools::gputools::StreamExecutor* stream_exec) override; + + StatusOr>> Compile( + std::vector> hlo_module, + std::vector> module_config, + HloDumper dump_hlo, + std::vector stream_exec) override; + + StatusOr> CompileAheadOfTime( + std::unique_ptr module, + std::unique_ptr module_config, HloDumper dump_hlo, + AotCompilationOptions const& options) override; + + perftools::gputools::Platform::Id PlatformId() const override; + + private: + // The parent directory of libdevice IR libraries. + const string libdevice_dir_; + + // The list of PTX strings generated by this GpuCompiler. We let GpuCompiler + // to own them because they need to be alive across the life span of the + // StreamExecutor (b/24776264). + tensorflow::mutex mutex_; + std::vector> generated_ptxes_ GUARDED_BY(mutex_); + + TF_DISALLOW_COPY_AND_ASSIGN(GpuCompiler); +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_COMPILER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc new file mode 100644 index 0000000000..f654ffd22d --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -0,0 +1,454 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/logical_buffer.h" +#include "tensorflow/compiler/xla/service/shaped_buffer.h" +#include "tensorflow/compiler/xla/service/transfer_manager.h" +#include "tensorflow/compiler/xla/shape_tree.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace se = ::perftools::gputools; + +namespace xla { +namespace gpu { +namespace { + +// A helper class for profiling HLO in the course of GPU program execution. +// All of the profiling is guarded internally, to avoid the caller needing to +// have lots of conditionals sprinkled around. +class HloExecutionProfiler { + public: + // If profiling is enabled, start an execution timer running. + explicit HloExecutionProfiler(bool do_profile, HloExecutionProfile* profile, + se::Stream* stream) + : do_profile_(do_profile), profile_(profile), stream_(stream) { + if (do_profile_) { + clock_rate_ghz_ = + stream->parent()->GetDeviceDescription().clock_rate_ghz(); + execution_timer_.reset(new se::Timer(stream->parent())); + per_op_timer_.reset(new se::Timer(stream->parent())); + stream->InitTimer(execution_timer_.get()) + .ThenStartTimer(execution_timer_.get()); + stream->InitTimer(per_op_timer_.get()); + } + } + + // If profiling is enabled, sets the total cycle count on the profile from the + // execution timer. + ~HloExecutionProfiler() { + if (do_profile_) { + stream_->ThenStopTimer(execution_timer_.get()); + stream_->BlockHostUntilDone(); + profile_->set_total_cycles_executed(execution_timer_->Nanoseconds() * + clock_rate_ghz_); + } + } + + // If profiling is enabled, starts the per-operation timer. + void StartOperation() { + if (do_profile_) { + stream_->ThenStartTimer(per_op_timer_.get()); + } + } + + // If profiling is enabled, stops the per-operation timer and records the time + // that the hlo_instruction took to execute in the profile. + void FinishOperation(const HloInstruction* hlo_instruction) { + if (do_profile_) { + stream_->ThenStopTimer(per_op_timer_.get()); + stream_->BlockHostUntilDone(); + profile_->AddProfileResult( + hlo_instruction, per_op_timer_->Nanoseconds() * clock_rate_ghz_); + } + } + + private: + const bool do_profile_; + double clock_rate_ghz_; + HloExecutionProfile* profile_; + se::Stream* stream_; + std::unique_ptr execution_timer_; + std::unique_ptr per_op_timer_; +}; + +} // namespace + +// Implementation note: HLO profiling is always enabled for GPU executables, +// since we can use timers around thunks. 
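+//
+// A minimal sketch of how the profiler is meant to wrap a thunk sequence
+// (`thunks`, `profile`, `stream` and `buffer_allocations` stand in for the
+// caller's values):
+//
+//   HloExecutionProfiler profiler(/*do_profile=*/profile != nullptr, profile,
+//                                 stream);
+//   for (Thunk* thunk : thunks) {
+//     profiler.StartOperation();
+//     TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(buffer_allocations, stream));
+//     profiler.FinishOperation(thunk->hlo_instruction());
+//   }
+//   // The profiler's destructor records the total cycle count on `profile`.
+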
+GpuExecutable::GpuExecutable( + tensorflow::StringPiece ptx, std::unique_ptr thunk_schedule, + std::unique_ptr hlo_module, + std::unique_ptr module_config, + std::unique_ptr assignment, + std::unique_ptr temp_buffer_offsets) + : Executable(std::move(hlo_module), std::move(module_config)), + ptx_(ptx), + thunk_schedule_(std::move(thunk_schedule)), + assignment_(std::move(assignment)), + temp_buffer_offsets_(std::move(temp_buffer_offsets)) {} + +Status GpuExecutable::ExecuteThunks( + se::Stream* main_stream, const BufferAllocations& buffer_allocations, + HloExecutionProfile* hlo_execution_profile) { + bool do_profile = hlo_execution_profile != nullptr; + if (do_profile) { + LOG(WARNING) << "PROFILING: profiling is enabled"; + } + HloExecutionProfiler profiler(do_profile, hlo_execution_profile, main_stream); + + std::vector> sub_streams; + // Stream 0 indicates `main_stream` and substreams start from stream 1. + for (int32 i = 1; i < thunk_schedule_->StreamCount(); ++i) { + auto sub_stream = MakeUnique(main_stream->parent()); + sub_stream->Init(); + sub_streams.emplace_back(std::move(sub_stream)); + } + + std::map> thunk_to_finish_event; + for (Thunk* thunk : thunk_schedule_->TotalOrder()) { + TF_RETURN_IF_ERROR(thunk->Initialize(*this)); + int32 stream_no = + thunk_schedule_->StreamNumberForHlo(*thunk->hlo_instruction()); + se::Stream* stream = + (stream_no == 0 ? main_stream : sub_streams[stream_no - 1].get()); + + for (const Thunk* dependency : thunk_schedule_->DependsOn(thunk)) { + stream->ThenWaitFor(FindOrDie(thunk_to_finish_event, dependency).get()); + } + + profiler.StartOperation(); + VLOG(2) << "Executing the thunk for " + << thunk->hlo_instruction()->ToString(); + TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(buffer_allocations, stream)); + if (thunk_schedule_->Depended(thunk)) { + auto finish_event = MakeUnique(main_stream->parent()); + finish_event->Init(); + stream->ThenRecordEvent(finish_event.get()); + thunk_to_finish_event[thunk] = std::move(finish_event); + } + profiler.FinishOperation(thunk->hlo_instruction()); + } + + main_stream->ThenWaitFor(&sub_streams); + // Make sure kernels are completed before deallocating temporary buffers. + // TODO(b/30100571): we could potentially postpone deallocating the temp + // buffers until a different computation is executed. + if (!main_stream->BlockHostUntilDone()) { + return InternalError("Failed to complete all kernels launched on stream %p", + main_stream); + } + + return Status::OK(); +} + +StatusOr GpuExecutable::ExecuteOnStream( + const ExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice arguments, + HloExecutionProfile* hlo_execution_profile) { + se::Stream* stream = run_options->stream(); + DeviceMemoryAllocator* memory_allocator = run_options->allocator(); + // This ExecuteOnStream overload should only be called if has_hybrid_result is + // false. 
+ TF_RET_CHECK(!module_config().has_hybrid_result()); + + BufferAllocations::Builder buffer_allocations_builder; + for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size(); + ++i) { + const BufferAllocation& allocation = assignment_->GetAllocation(i); + if (allocation.is_entry_computation_parameter()) { + buffer_allocations_builder.RegisterBuffer( + i, arguments[allocation.parameter_number()]); + } + } + se::StreamExecutor* executor = stream->parent(); + TF_ASSIGN_OR_RETURN(auto buffer_allocations, + buffer_allocations_builder.Build( + *assignment_, *temp_buffer_offsets_, + executor->device_ordinal(), memory_allocator)); + + TF_RETURN_IF_ERROR( + ExecuteThunks(stream, *buffer_allocations, hlo_execution_profile)); + + HloInstruction* root = hlo_module_->entry_computation()->root_instruction(); + TF_ASSIGN_OR_RETURN(const BufferAllocation* output_allocation, + assignment_->GetUniqueTopLevelOutputAllocation()); + se::DeviceMemoryBase output_buffer_address = + buffer_allocations->GetDeviceAddress(output_allocation->index()); + + if (ShapeUtil::IsTuple(root->shape())) { + std::set referred_by_output; + if (GetRootPointsToSet().IsAmbiguous()) { + // The points-to set of the root is ambiguous so we need to examine the + // result data to determine which buffers are contained in the result. + TF_ASSIGN_OR_RETURN( + TransferManager * transfer_manager, + TransferManager::GetForPlatform(executor->platform())); + TF_ASSIGN_OR_RETURN(referred_by_output, + transfer_manager->GatherBufferPointersFromTuple( + executor, output_buffer_address, root->shape())); + } else { + // The points-to set of the root is unambiguous so it's known statically + // which buffers are in the result. Gather these buffers using the root's + // points-to set. + TF_RETURN_IF_ERROR(GetRootPointsToSet().ForEachElement( + [&referred_by_output, &buffer_allocations, this]( + const ShapeIndex& /*index*/, bool /*is_leaf*/, + const std::vector& buffers) { + // The points to set is unambiguous so the set should be a + // singleton. That is, we know exactly which instruction produced + // the array at this element. + CHECK_EQ(1, buffers.size()); + HloInstruction* hlo = buffers[0]->instruction(); + TF_ASSIGN_OR_RETURN(const BufferAllocation* allocation, + this->assignment_->GetUniqueAllocation( + hlo, buffers[0]->index())); + CHECK(!allocation->is_entry_computation_parameter()); + referred_by_output.insert( + buffer_allocations->GetDeviceAddress(allocation->index())); + return Status::OK(); + })); + } + TF_RETURN_IF_ERROR( + buffer_allocations->TearDown(referred_by_output, *assignment_)); + } else { + // If the computation result is not a tuple, we can delete all temporary + // buffers that are not the output. + TF_RETURN_IF_ERROR( + buffer_allocations->TearDown({output_buffer_address}, *assignment_)); + } + return output_buffer_address; +} + +StatusOr> GpuExecutable::ExecuteOnStream( + const ExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice arguments, + HloExecutionProfile* hlo_execution_profile) { + se::Stream* stream = run_options->stream(); + DeviceMemoryAllocator* memory_allocator = run_options->allocator(); + // This ExecuteOnStream overload should only be called by the LocalService + // which sets has_hybrid_result to true. 
+ TF_RET_CHECK(module_config().has_hybrid_result()); + + if (GetRootPointsToSet().IsAmbiguous()) { + return Unimplemented("Points-to set of root instruction is ambiguous"); + } + + BufferAllocations::Builder buffer_allocations_builder; + for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size(); + ++i) { + const BufferAllocation& allocation = assignment_->GetAllocation(i); + if (allocation.is_entry_computation_parameter()) { + auto param_no = allocation.parameter_number(); + if (ShapeUtil::IsTuple(arguments[param_no]->shape())) { + return Unimplemented("Tuple ShapedBuffer arguments not supported"); + } + buffer_allocations_builder.RegisterBuffer( + i, arguments[param_no]->buffer(/*index=*/{})); + } + } + se::StreamExecutor* executor = stream->parent(); + TF_ASSIGN_OR_RETURN(auto buffer_allocations, + buffer_allocations_builder.Build( + *assignment_, *temp_buffer_offsets_, + executor->device_ordinal(), memory_allocator)); + + TF_RETURN_IF_ERROR( + ExecuteThunks(stream, *buffer_allocations, hlo_execution_profile)); + + HloInstruction* root = hlo_module_->entry_computation()->root_instruction(); + auto device_ordinal = executor->device_ordinal(); + TF_ASSIGN_OR_RETURN(auto shaped_buffer, + ShapedBuffer::MakeShapedBuffer( + root->shape(), executor->platform(), device_ordinal)); + + // Copy DeviceMemoryBase values which contain the array(s) of the result into + // the respective location in ShapedBuffer. + std::set buffers_in_result; + TF_RETURN_IF_ERROR( + shaped_buffer->mutable_shape_index_to_buffer_entry() + ->ForEachMutableElement( + [&buffer_allocations, &buffers_in_result, &shaped_buffer, this]( + const ShapeIndex& index, bool is_leaf, size_t* buffer_entry) { + if (is_leaf) { + const std::vector& sources = + this->GetRootPointsToSet().element(index); + // The points to set is unambiguous so the set should be a + // singleton. That is, we know exactly which instruction + // produced the array at this element. + CHECK_EQ(1, sources.size()); + auto src_hlo = sources[0]->instruction(); + + VLOG(4) << "Looking at: " << sources[0]; + + // The source instruction should have a non-parameter buffer + // assigned. + TF_ASSIGN_OR_RETURN(const BufferAllocation* allocation, + this->assignment_->GetUniqueAllocation( + src_hlo, sources[0]->index())); + CHECK(!allocation->is_entry_computation_parameter()); + + perftools::gputools::DeviceMemoryBase src_base = + buffer_allocations->GetDeviceAddress(allocation->index()); + CHECK(!src_base.is_null() || src_base.size() == 0); + shaped_buffer->mutable_buffers()->push_back(src_base); + *buffer_entry = shaped_buffer->mutable_buffers()->size() - 1; + + buffers_in_result.insert(src_base); + } + return Status::OK(); + })); + TF_RETURN_IF_ERROR( + buffer_allocations->TearDown(buffers_in_result, *assignment_)); + + return std::move(shaped_buffer); +} + +Status GpuExecutable::ExecuteOnStream( + const ExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice arguments, + ShapedBuffer* result_buffer, HloExecutionProfile* hlo_execution_profile) { + se::Stream* stream = run_options->stream(); + DeviceMemoryAllocator* memory_allocator = run_options->allocator(); + // This ExecuteOnStream overload should only be called by the LocalService + // which sets has_hybrid_result to true. + TF_RET_CHECK(module_config().has_hybrid_result()); + + // Every array element in the result of the computation must be unambiguously + // produced by a single instruction. 
+ // This ensures that the buffers inside result_buffer can be assigned without + // conflict to the respective instructions because there is a one-to-one + // correspondence between hlo instructions and array buffers in the result. + if (GetRootPointsToSet().IsAmbiguous()) { + return Unimplemented( + "Points-to set of root instruction is ambiguous or not distinct"); + } + + DCHECK(ShapeUtil::Compatible(result_buffer->shape(), result_shape())); + + BufferAllocations::Builder buffer_allocations_builder; + for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size(); + ++i) { + const BufferAllocation& allocation = assignment_->GetAllocation(i); + if (allocation.is_entry_computation_parameter()) { + auto param_no = allocation.parameter_number(); + if (ShapeUtil::IsTuple(arguments[param_no]->shape())) { + return Unimplemented("Tuple ShapedBuffer arguments not supported"); + } + buffer_allocations_builder.RegisterBuffer( + i, arguments[param_no]->buffer(/*index=*/{})); + } + } + + // If two tuple elements point to the same buffer, one of the results in the + // result buffer is considered the canonical location while the other result + // points to it (instead of, say, making a copy of the result). + // buffer_index_to_shape_index maps a buffer index to its canonical location + // in the result buffer. + std::unordered_map + buffer_index_to_shape_index; + + // Register DeviceMemoryBase values in result_buffer to their corresponding + // buffer indices. These buffers will not be allocated in the call to + // BufferAllocationsBuilder::Build. + std::set buffers_in_result; + TF_RETURN_IF_ERROR( + result_buffer->mutable_shape_index_to_buffer_entry() + ->ForEachMutableElement( + [&buffer_allocations_builder, &buffers_in_result, + &buffer_index_to_shape_index, result_buffer, this]( + const ShapeIndex& index, bool is_leaf, size_t* buffer_entry) { + if (is_leaf) { + const std::vector& sources = + this->GetRootPointsToSet().element(index); + // The points to set is unambiguous so the set should be a + // singleton. That is, we know exactly which instruction + // produced the array at this element. + CHECK_EQ(1, sources.size()); + auto src_hlo = sources[0]->instruction(); + + VLOG(4) << "Looking at: " << sources[0]; + + // The source instruction should have a non-parameter buffer + // assigned. + TF_ASSIGN_OR_RETURN(const BufferAllocation* allocation, + this->assignment_->GetUniqueAllocation( + src_hlo, sources[0]->index())); + CHECK(!allocation->is_entry_computation_parameter()); + + auto insert_result = buffer_index_to_shape_index.emplace( + allocation->index(), *buffer_entry); + if (insert_result.second) { + // The points-to set is distinct so this buffer should not + // have been assigned in a previous invocation of this + // lambda. + perftools::gputools::DeviceMemoryBase memory_base = + result_buffer->buffer(index); + CHECK(!memory_base.is_null()); + buffer_allocations_builder.RegisterBuffer( + allocation->index(), memory_base); + buffers_in_result.insert(memory_base); + } else { + // Record the fact that this tuple element is identical to + // some + // prior result. 
+ *buffer_entry = insert_result.first->second; + } + } + return Status::OK(); + })); + + se::StreamExecutor* executor = stream->parent(); + auto device_ordinal = executor->device_ordinal(); + TF_ASSIGN_OR_RETURN( + auto buffer_allocations, + buffer_allocations_builder.Build(*assignment_, *temp_buffer_offsets_, + device_ordinal, memory_allocator)); + + TF_RETURN_IF_ERROR( + ExecuteThunks(stream, *buffer_allocations, hlo_execution_profile)); + + return buffer_allocations->TearDown(buffers_in_result, *assignment_); +} + +StatusOr GpuExecutable::ExecuteAsyncOnStream( + const ExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice arguments) { + // TODO(b/30671675): Implement asynchronous execution mode. + return Unimplemented( + "Asynchronous execution on stream is not yet supported on GPU."); +} + +const PointsToSet& GpuExecutable::GetRootPointsToSet() const { + return assignment_->points_to_analysis().GetPointsToSet( + module().entry_computation()->root_instruction()); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h new file mode 100644 index 0000000000..2343d264de --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h @@ -0,0 +1,130 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_EXECUTABLE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_EXECUTABLE_H_ + +#include +#include + +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/executable.h" +#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/temp_buffer_offsets.h" +#include "tensorflow/compiler/xla/service/gpu/thunk.h" +#include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h" +#include "tensorflow/compiler/xla/service/hlo_execution_profile.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/shaped_buffer.h" +#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { + +// GPU-targeting implementation of the XLA Executable interface. +// +// Launches the given CUDA kernel via the StreamExecutor. +// +// This is an immutable data type after initialization, and thus thread safe. 
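+//
+// A rough sketch of the expected call sequence (here `unoptimized_ir`,
+// `run_options` and `arguments` are stand-ins for the caller's values; the
+// other objects are the ones produced by GpuCompiler):
+//
+//   auto executable = MakeUnique<GpuExecutable>(
+//       ptx, std::move(thunk_schedule), std::move(hlo_module),
+//       std::move(module_config), std::move(buffer_assignment),
+//       std::move(temp_buffer_offsets));
+//   executable->set_ir_module_string(unoptimized_ir);  // optional
+//   TF_ASSIGN_OR_RETURN(
+//       auto result,
+//       executable->ExecuteOnStream(&run_options, arguments,
+//                                   /*hlo_execution_profile=*/nullptr));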
+class GpuExecutable : public Executable {
+ public:
+  GpuExecutable(tensorflow::StringPiece ptx,
+                std::unique_ptr<ThunkSchedule> thunk_schedule,
+                std::unique_ptr<HloModule> hlo_module,
+                std::unique_ptr<HloModuleConfig> module_config,
+                std::unique_ptr<BufferAssignment> assignment,
+                std::unique_ptr<TempBufferOffsets> temp_buffer_offsets);
+
+  // This should be called after set_ir_module_string.
+  const string& ir_module_string() const { return ir_module_string_; }
+
+  // This should be called before ExecuteOnStream.
+  void set_ir_module_string(const string& ir_module_string) {
+    ir_module_string_ = ir_module_string;
+  }
+
+  // Returns the compiled PTX for the computation.
+  tensorflow::StringPiece ptx() const { return ptx_; }
+
+  StatusOr<perftools::gputools::DeviceMemoryBase> ExecuteOnStream(
+      const ExecutableRunOptions* run_options,
+      tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
+          arguments,
+      HloExecutionProfile* hlo_execution_profile) override;
+
+  StatusOr<std::unique_ptr<ShapedBuffer>> ExecuteOnStream(
+      const ExecutableRunOptions* run_options,
+      tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+      HloExecutionProfile* hlo_execution_profile) override;
+
+  Status ExecuteOnStream(
+      const ExecutableRunOptions* run_options,
+      tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+      ShapedBuffer* result_buffer,
+      HloExecutionProfile* hlo_execution_profile) override;
+
+  StatusOr<perftools::gputools::DeviceMemoryBase> ExecuteAsyncOnStream(
+      const ExecutableRunOptions* run_options,
+      tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
+          arguments) override;
+
+ private:
+  Status ExecuteThunks(perftools::gputools::Stream* stream,
+                       const BufferAllocations& buffer_allocations,
+                       HloExecutionProfile* hlo_execution_profile);
+
+  // Returns the points-to set of the root instruction of the entry
+  // computation. Uses points-to analysis from buffer assignment.
+  const PointsToSet& GetRootPointsToSet() const;
+
+  // The LLVM IR, in string format, of the unoptimized module generated for
+  // this GpuExecutable. We save a string instead of an llvm::Module* because
+  // leaving llvm::Module* in a singleton can cause the heap checker to emit
+  // false positives.
+  //
+  // This string should be modified only before ExecuteOnStream.
+  string ir_module_string_;
+
+  // The reference to the compiled PTX for the computation.
+  const tensorflow::StringPiece ptx_;
+
+  // The thunks to be invoked by this GpuExecutable. They are generated by the
+  // IrEmitter.
+  const std::unique_ptr<ThunkSchedule> thunk_schedule_;
+
+  // Owns the buffer data at runtime. It provides information to allocate
+  // memory for every output/temp buffer.
+  const std::unique_ptr<BufferAssignment> assignment_;
+
+  // Owns the mapping from temporary buffers to their offsets in the
+  // temp-buffer memory block.
+  const std::unique_ptr<TempBufferOffsets> temp_buffer_offsets_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(GpuExecutable);
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_EXECUTABLE_H_
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc
new file mode 100644
index 0000000000..c56b0ee9ca
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc
@@ -0,0 +1,207 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h" + +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { +namespace gpu { + +namespace { + +// An HLO partial ordering based on the actual stream assignment and thunk +// launch order. +class GpuHloOrdering : public PredecessorHloOrdering { + public: + GpuHloOrdering(const HloModule* module, + const StreamAssignment& stream_assignment, + const std::vector& thunk_launch_order); + ~GpuHloOrdering() override = default; + + string ToString() const override { return ToStringHelper("GpuHloOrdering"); } +}; + +GpuHloOrdering::GpuHloOrdering( + const HloModule* module, const StreamAssignment& stream_assignment, + const std::vector& thunk_launch_order) + : PredecessorHloOrdering(module) { + // The ordering of instructions for the entry computation is determined by the + // total order of thunk launches, and stream assignment. Instructions are + // sequential within a stream and concurrent across streams. In addition, the + // GpuExecutable adds cross-stream dependency edges to ensure each instruction + // waits for its operands before executing. + // + // The predecessor map is built incrementally, in thunk launch + // order. We record the instructions already visited per stream in + // 'instructions_per_stream'. This lets us quickly determine the + // same-stream predecessors of each instruction. To capture + // cross-stream dependency edges, we use the predecessor map to + // insert each operand as well as its transitive closure of + // dependencies. + + // Compute the set of all instructions we will want to set reachability on + auto predecessor_map = MakeUnique( + module->entry_computation()->MakeInstructionPostOrder()); + + std::vector> instructions_per_stream( + stream_assignment.StreamCount()); + + for (const HloInstruction* hlo : thunk_launch_order) { + if (stream_assignment.HasStreamAssigned(*hlo)) { + // All ops already queued on the same stream are predecessors. + const int stream_no = stream_assignment.StreamNumberForHlo(*hlo); + for (const HloInstruction* inst : instructions_per_stream[stream_no]) { + predecessor_map->SetReachable(hlo, inst); + } + // All operands and their transitive predecessors are predecessors. Each + // operand must already exist in 'predecessor_map', since we're iterating + // in thunk launch order. + for (const HloInstruction* operand : hlo->operands()) { + predecessor_map->SetReachableAndTransitiveClosure(hlo, operand); + } + instructions_per_stream[stream_no].push_back(hlo); + } else { + // Only parameters and constants don't have an assigned stream, since they + // don't require a thunk. These ops don't have any predecessors. + CHECK(hlo->opcode() == HloOpcode::kParameter || + hlo->opcode() == HloOpcode::kConstant); + CHECK_EQ(hlo->operand_count(), 0); + } + } + strict_predecessors_.emplace(module->entry_computation(), + std::move(predecessor_map)); + + // The ordering of instructions in subcomputations is based solely on data + // dependencies. I.e. the strict predecessors of each subcomputation + // instruction is its transitive operands. + // + // TODO(toddw): Each subcomputation is actually emitted as a function in + // DFS + // postorder, so we can do better and establish the total order here. 
We + // don't + // do that yet since it's hard to ensure that the order here is the order + // used + // by IrEmitterNested. And mismatched ordering bugs would be hard to find. + for (auto& computation : module->computations()) { + if (computation.get() != module->entry_computation()) { + strict_predecessors_.emplace(computation.get(), + computation->ComputeTransitiveOperands()); + } + } +} + +// Computes a topological launch_order based on depth-first order, visiting +// operands in essentially an arbitrary order. +// +// TODO(b/32006145): Use an ordering that minimizes memory pressure. +tensorflow::Status DFSLaunchOrder( + const HloComputation* computation, + std::vector* launch_order) { + return computation->root_instruction()->Accept( + [launch_order](HloInstruction* hlo) { + launch_order->push_back(hlo); + return tensorflow::Status::OK(); + }); +} + +// Computes a topological launch_order that is close to a breadth-first +// order. This heuristic works well for graphs where concurrent kernels are +// located at the same layer. It can often reduce dependency between concurrent +// GEMMs due to intra-stream total orders. E.g. consider the following HLO +// graph where the numbers in the parens indicate the stream assigned to each +// HLO. +// +// A(0) -> D(0) -> E(1) +// | +// v +// B(0) +// | +// v +// C(0) +// +// If the total order is A,B,C,D,E, then C and E would be sequentialized +// because C completes before D starts in stream 0, and E depends on D. +// However, if the total order is A,B,D,C,E, then C and E can run +// concurrently. +void BFSLaunchOrder(const HloComputation* computation, + std::vector* launch_order) { + // This topological sort uses two data structures: + // 1. `incoming_edge_count` which keeps track of the number of incoming + // edges to each HLO; + // 2. `queue` which contains all HLOs with no incoming edges. + // + // The sorting algorithm repeatedly pops the top from the queue and deletes + // that HLO from the graph, making more HLOs incoming-edge free. + std::deque queue; + std::unordered_map incoming_edge_count; + for (const auto& hlo : computation->instructions()) { + if (hlo->operand_count() == 0) { + queue.push_back(hlo.get()); + } else { + incoming_edge_count[hlo.get()] = + std::set(hlo->operands().begin(), + hlo->operands().end()) + .size(); + } + } + + while (!queue.empty()) { + const HloInstruction* x = queue.front(); + queue.pop_front(); + launch_order->push_back(x); + for (const HloInstruction* y : x->users()) { + --incoming_edge_count[y]; + if (incoming_edge_count[y] == 0) { + queue.push_back(y); + } + } + } +} + +} // end namespace + +HloSchedule::HloSchedule() {} + +/* static */ +StatusOr> HloSchedule::Build( + const HloModule& module, const StreamAssignment& stream_assignment) { + std::unique_ptr schedule(new HloSchedule); + + // Initialize thunk_launch_order_, the total order of thunk launches. + const HloComputation* computation = module.entry_computation(); + if (stream_assignment.StreamCount() == 1) { + // DFS tends to increase buffer reuse, reducing memory usage. All kernels + // are launched on a single stream, so there's no loss of concurrency. + TF_RETURN_IF_ERROR( + DFSLaunchOrder(computation, &schedule->thunk_launch_order_)); + } else { + // BFS tends to increase concurrency, but also increases memory usage. 
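+    // For instance, in the example graph in the comment on BFSLaunchOrder,
+    // the queue starts with A (the only instruction without operands);
+    // popping A makes both B and D ready, so D is launched before C and the
+    // stream-1 kernel E can overlap with C.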
+    BFSLaunchOrder(computation, &schedule->thunk_launch_order_);
+  }
+
+  schedule->hlo_ordering_ = MakeUnique<GpuHloOrdering>(
+      &module, stream_assignment, schedule->thunk_launch_order_);
+
+  return std::move(schedule);
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule.h b/tensorflow/compiler/xla/service/gpu/hlo_schedule.h
new file mode 100644
index 0000000000..42d9051aed
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule.h
@@ -0,0 +1,67 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_SCHEDULE_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_SCHEDULE_H_
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/buffer_liveness.h"
+#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+// Determines the schedule of HLO instructions, represented by the total order
+// of thunk launches, and the partial order of HLO instructions. The HLO
+// instructions are only partially ordered, despite the total ordering of thunk
+// launches, because thunks may be scheduled onto concurrent streams. This
+// schedule is used by BufferAssigner to determine buffer liveness (i.e. to
+// minimize allocations), and also by ThunkSchedule to determine the thunk
+// launch order.
+class HloSchedule {
+ public:
+  // Constructs an HloSchedule for the given module, based on the given stream
+  // assignment.
+  static StatusOr<std::unique_ptr<HloSchedule>> Build(
+      const HloModule& module, const StreamAssignment& stream_assignment);
+
+  // Returns the total order of thunk launches, represented in terms of HLO
+  // instructions.
+  const std::vector<const HloInstruction*>& ThunkLaunchOrder() const {
+    return thunk_launch_order_;
+  }
+
+  // Returns the partial order of HLO instructions. This method may only be
+  // called once. The order is based on the total order of thunk launches, the
+  // stream assignment, and the data dependencies in the HLO DAG.
+  std::unique_ptr<HloOrdering> ConsumeHloOrdering() {
+    return std::move(hlo_ordering_);
+  }
+
+ private:
+  HloSchedule();
+
+  std::vector<const HloInstruction*> thunk_launch_order_;
+  std::unique_ptr<HloOrdering> hlo_ordering_;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_SCHEDULE_H_
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc
new file mode 100644
index 0000000000..174982a6ce
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc
@@ -0,0 +1,368 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h" + +#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { +namespace gpu { + +class HloScheduleTest : public HloTestBase { + protected: + typedef std::vector hlovec; + + // Pre-canned shapes. + Shape f32_2x2_ = ShapeUtil::MakeShape(F32, {2, 2}); +}; + +// Test of a single stream, where data dependencies fully determine the +// execution order. +TEST_F(HloScheduleTest, SequentialMatMul) { + HloComputation::Builder builder("entry_computation"); + HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, f32_2x2_, /*name=*/"x")); + HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, f32_2x2_, /*name=*/"y")); + HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/2, f32_2x2_, /*name=*/"z")); + HloInstruction* dot1 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, x, y)); + HloInstruction* dot2 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, dot1, z)); + + HloModule module(TestName()); + module.AddEntryComputation(builder.Build(dot2)); + + std::unique_ptr streams = AssignStreams(module); + EXPECT_EQ(streams->StreamNumberForHlo(*dot1), + streams->StreamNumberForHlo(*dot2)); + + auto schedule = HloSchedule::Build(module, *streams).ConsumeValueOrDie(); + EXPECT_EQ(schedule->ThunkLaunchOrder(), hlovec({x, y, dot1, z, dot2})); + + // Parameters x,y,z are mutually unordered, while dot1 and dot2 are + // transitively ordered by operands. 
+ auto order = schedule->ConsumeHloOrdering(); + EXPECT_TRUE(order->ExecutesBefore(x, dot1)); + EXPECT_TRUE(order->ExecutesBefore(x, dot2)); + EXPECT_TRUE(order->ExecutesBefore(y, dot1)); + EXPECT_TRUE(order->ExecutesBefore(y, dot2)); + EXPECT_TRUE(order->ExecutesBefore(z, dot2)); + EXPECT_TRUE(order->ExecutesBefore(dot1, dot2)); + + EXPECT_FALSE(order->ExecutesBefore(x, x)); + EXPECT_FALSE(order->ExecutesBefore(x, y)); + EXPECT_FALSE(order->ExecutesBefore(x, z)); + EXPECT_FALSE(order->ExecutesBefore(y, x)); + EXPECT_FALSE(order->ExecutesBefore(y, y)); + EXPECT_FALSE(order->ExecutesBefore(y, z)); + EXPECT_FALSE(order->ExecutesBefore(z, x)); + EXPECT_FALSE(order->ExecutesBefore(z, y)); + EXPECT_FALSE(order->ExecutesBefore(z, z)); + EXPECT_FALSE(order->ExecutesBefore(z, dot1)); + EXPECT_FALSE(order->ExecutesBefore(dot1, x)); + EXPECT_FALSE(order->ExecutesBefore(dot1, y)); + EXPECT_FALSE(order->ExecutesBefore(dot1, z)); + EXPECT_FALSE(order->ExecutesBefore(dot1, dot1)); + EXPECT_FALSE(order->ExecutesBefore(dot2, x)); + EXPECT_FALSE(order->ExecutesBefore(dot2, y)); + EXPECT_FALSE(order->ExecutesBefore(dot2, z)); + EXPECT_FALSE(order->ExecutesBefore(dot2, dot1)); + EXPECT_FALSE(order->ExecutesBefore(dot2, dot2)); +} + +// Test of a single stream, where data dependencies do not fully determine the +// execution order, but the stream assignment does. +TEST_F(HloScheduleTest, SequentialAdd) { + HloComputation::Builder builder("entry_computation"); + HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, f32_2x2_, /*name=*/"x")); + HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, f32_2x2_, /*name=*/"y")); + HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/2, f32_2x2_, /*name=*/"z")); + HloInstruction* add1 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, x, y)); + HloInstruction* add2 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, y, z)); + HloInstruction* add3 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, add1, add2)); + + HloModule module(TestName()); + module.AddEntryComputation(builder.Build(add3)); + + std::unique_ptr streams = AssignStreams(module); + EXPECT_EQ(streams->StreamNumberForHlo(*add1), + streams->StreamNumberForHlo(*add2)); + EXPECT_EQ(streams->StreamNumberForHlo(*add1), + streams->StreamNumberForHlo(*add3)); + + auto schedule = HloSchedule::Build(module, *streams).ConsumeValueOrDie(); + EXPECT_EQ(schedule->ThunkLaunchOrder(), hlovec({x, y, add1, z, add2, add3})); + + // Parameters x,y,z are mutually unordered, while add1, add2 and add3 are + // transitively ordered by operands. + auto order = schedule->ConsumeHloOrdering(); + EXPECT_TRUE(order->ExecutesBefore(x, add1)); + EXPECT_TRUE(order->ExecutesBefore(x, add3)); + EXPECT_TRUE(order->ExecutesBefore(y, add1)); + EXPECT_TRUE(order->ExecutesBefore(y, add2)); + EXPECT_TRUE(order->ExecutesBefore(y, add3)); + EXPECT_TRUE(order->ExecutesBefore(z, add2)); + EXPECT_TRUE(order->ExecutesBefore(z, add3)); + EXPECT_TRUE(order->ExecutesBefore(add1, add3)); + EXPECT_TRUE(order->ExecutesBefore(add2, add3)); + // The HLO graph does not define an ordering for add1 and add2, but their + // assignment onto the same stream does define an ordering. 
+ if (order->ExecutesBefore(add1, add2)) { + EXPECT_FALSE(order->ExecutesBefore(add2, add1)); + } else { + EXPECT_TRUE(order->ExecutesBefore(add2, add1)); + EXPECT_FALSE(order->ExecutesBefore(add1, add2)); + } + + EXPECT_FALSE(order->ExecutesBefore(x, x)); + EXPECT_FALSE(order->ExecutesBefore(x, y)); + EXPECT_FALSE(order->ExecutesBefore(x, z)); + EXPECT_FALSE(order->ExecutesBefore(y, x)); + EXPECT_FALSE(order->ExecutesBefore(y, y)); + EXPECT_FALSE(order->ExecutesBefore(y, z)); + EXPECT_FALSE(order->ExecutesBefore(z, x)); + EXPECT_FALSE(order->ExecutesBefore(z, y)); + EXPECT_FALSE(order->ExecutesBefore(z, z)); + EXPECT_FALSE(order->ExecutesBefore(x, add2)); + EXPECT_FALSE(order->ExecutesBefore(z, add1)); + EXPECT_FALSE(order->ExecutesBefore(add1, x)); + EXPECT_FALSE(order->ExecutesBefore(add1, y)); + EXPECT_FALSE(order->ExecutesBefore(add1, z)); + EXPECT_FALSE(order->ExecutesBefore(add1, add1)); + EXPECT_FALSE(order->ExecutesBefore(add2, x)); + EXPECT_FALSE(order->ExecutesBefore(add2, y)); + EXPECT_FALSE(order->ExecutesBefore(add2, z)); + EXPECT_FALSE(order->ExecutesBefore(add2, add2)); +} + +// Test of two streams. +TEST_F(HloScheduleTest, ConcurrentMatMul) { + HloComputation::Builder builder("entry_computation"); + HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, f32_2x2_, /*name=*/"x")); + HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, f32_2x2_, /*name=*/"y")); + HloInstruction* dot1 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, x, y)); + HloInstruction* dot2 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, y, x)); + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, dot1, dot2)); + + HloModule module(TestName()); + module.AddEntryComputation(builder.Build(add)); + + std::unique_ptr streams = AssignStreams(module); + EXPECT_NE(streams->StreamNumberForHlo(*dot1), + streams->StreamNumberForHlo(*dot2)); + + auto schedule = HloSchedule::Build(module, *streams).ConsumeValueOrDie(); + EXPECT_TRUE(schedule->ThunkLaunchOrder() == hlovec({x, y, dot1, dot2, add}) || + schedule->ThunkLaunchOrder() == hlovec({x, y, dot2, dot1, add})); + + // Parameters x,y are mutually unordered, while dot1, dot2 and add are + // transitively ordered by operands. 
+ auto order = schedule->ConsumeHloOrdering(); + EXPECT_TRUE(order->ExecutesBefore(x, dot1)); + EXPECT_TRUE(order->ExecutesBefore(x, dot2)); + EXPECT_TRUE(order->ExecutesBefore(y, dot1)); + EXPECT_TRUE(order->ExecutesBefore(y, dot2)); + EXPECT_TRUE(order->ExecutesBefore(dot1, add)); + EXPECT_TRUE(order->ExecutesBefore(dot2, add)); + + EXPECT_FALSE(order->ExecutesBefore(x, x)); + EXPECT_FALSE(order->ExecutesBefore(x, y)); + EXPECT_FALSE(order->ExecutesBefore(y, x)); + EXPECT_FALSE(order->ExecutesBefore(y, y)); + EXPECT_FALSE(order->ExecutesBefore(dot1, x)); + EXPECT_FALSE(order->ExecutesBefore(dot1, y)); + EXPECT_FALSE(order->ExecutesBefore(dot1, dot1)); + EXPECT_FALSE(order->ExecutesBefore(dot1, dot2)); + EXPECT_FALSE(order->ExecutesBefore(dot2, x)); + EXPECT_FALSE(order->ExecutesBefore(dot2, y)); + EXPECT_FALSE(order->ExecutesBefore(dot2, dot1)); + EXPECT_FALSE(order->ExecutesBefore(dot2, dot2)); + EXPECT_FALSE(order->ExecutesBefore(add, x)); + EXPECT_FALSE(order->ExecutesBefore(add, y)); + EXPECT_FALSE(order->ExecutesBefore(add, dot1)); + EXPECT_FALSE(order->ExecutesBefore(add, dot2)); + EXPECT_FALSE(order->ExecutesBefore(add, add)); +} + +// Test of multiple streams. +TEST_F(HloScheduleTest, LatticeMatMul) { + // d00 -- layer 0 + // / \ + // d10 d11 -- layer 1 + // / \ / \ + // d20 d21 d22 -- layer 2 + // \ / \ / + // d30 d31 -- layer 3 + // \ / + // d40 -- layer 4 + HloComputation::Builder builder("entry_computation"); + std::vector params; + for (int i = 0; i < 6; ++i) { + params.push_back(builder.AddInstruction(HloInstruction::CreateParameter( + i, f32_2x2_, /*name=*/tensorflow::strings::Printf("param%d", i)))); + } + HloInstruction* d00 = builder.AddInstruction(HloInstruction::CreateBinary( + f32_2x2_, HloOpcode::kDot, params[2], params[3])); + HloInstruction* d10 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, params[1], d00)); + HloInstruction* d11 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d00, params[4])); + HloInstruction* d20 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, params[0], d10)); + HloInstruction* d21 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d10, d11)); + HloInstruction* d22 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d11, params[5])); + HloInstruction* d30 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d20, d21)); + HloInstruction* d31 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d21, d22)); + HloInstruction* d40 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d30, d31)); + + HloModule module(TestName()); + module.AddEntryComputation(builder.Build(d40)); + + std::unique_ptr streams = AssignStreams(module); + // The two dots on layer 1 are concurrent. + EXPECT_NE(streams->StreamNumberForHlo(*d10), + streams->StreamNumberForHlo(*d11)); + // The three dots on layer 2 are concurrent. + EXPECT_NE(streams->StreamNumberForHlo(*d20), + streams->StreamNumberForHlo(*d21)); + EXPECT_NE(streams->StreamNumberForHlo(*d20), + streams->StreamNumberForHlo(*d22)); + EXPECT_NE(streams->StreamNumberForHlo(*d21), + streams->StreamNumberForHlo(*d22)); + // The two dots on layer 3 are concurrent. 
+ EXPECT_NE(streams->StreamNumberForHlo(*d30), + streams->StreamNumberForHlo(*d31)); + + // We don't check the thunk launch order, since there are many valid total + // orders, and it's annoying to express. + auto schedule = HloSchedule::Build(module, *streams).ConsumeValueOrDie(); + + auto order = schedule->ConsumeHloOrdering(); + const hlovec all_params( + {params[0], params[1], params[2], params[3], params[4], params[5]}); + const hlovec all_ops({d00, d10, d11, d20, d21, d22, d30, d31, d40}); + + // Parameters are mutually unordered, and never execute before ops. + for (const HloInstruction* param : all_params) { + for (const HloInstruction* param2 : all_params) { + EXPECT_FALSE(order->ExecutesBefore(param, param2)); + } + for (const HloInstruction* op : all_ops) { + EXPECT_FALSE(order->ExecutesBefore(op, param)); + } + } + + // Check ordering of params before ops. + for (const HloInstruction* op : all_ops) { + if (op == d20 || op == d30 || op == d40) { + EXPECT_TRUE(order->ExecutesBefore(params[0], op)); + } else { + EXPECT_FALSE(order->ExecutesBefore(params[0], op)); + } + if (op != d00 && op != d11 && op != d22) { + EXPECT_TRUE(order->ExecutesBefore(params[1], op)); + } else { + EXPECT_FALSE(order->ExecutesBefore(params[1], op)); + } + EXPECT_TRUE(order->ExecutesBefore(params[2], op)); + EXPECT_TRUE(order->ExecutesBefore(params[3], op)); + if (op != d00 && op != d10 && op != d20) { + EXPECT_TRUE(order->ExecutesBefore(params[4], op)); + } else { + EXPECT_FALSE(order->ExecutesBefore(params[4], op)); + } + if (op == d22 || op == d31 || op == d40) { + EXPECT_TRUE(order->ExecutesBefore(params[5], op)); + } else { + EXPECT_FALSE(order->ExecutesBefore(params[5], op)); + } + } + + // Check ordering of ops before ops. + for (const HloInstruction* op : all_ops) { + if (op != d00) { + EXPECT_TRUE(order->ExecutesBefore(d00, op)); + } else { + EXPECT_FALSE(order->ExecutesBefore(d00, op)); + } + + if (op == d20 || op == d21 || op == d30 || op == d31 || op == d40) { + EXPECT_TRUE(order->ExecutesBefore(d10, op)); + } else { + EXPECT_FALSE(order->ExecutesBefore(d10, op)); + } + + if (op == d21 || op == d22 || op == d30 || op == d31 || op == d40) { + EXPECT_TRUE(order->ExecutesBefore(d11, op)); + } else { + EXPECT_FALSE(order->ExecutesBefore(d11, op)); + } + + if (op == d30 || op == d40) { + EXPECT_TRUE(order->ExecutesBefore(d20, op)); + } else { + EXPECT_FALSE(order->ExecutesBefore(d20, op)); + } + + if (op == d30 || op == d31 || op == d40) { + EXPECT_TRUE(order->ExecutesBefore(d21, op)); + } else { + EXPECT_FALSE(order->ExecutesBefore(d21, op)); + } + + if (op == d31 || op == d40) { + EXPECT_TRUE(order->ExecutesBefore(d22, op)); + } else { + EXPECT_FALSE(order->ExecutesBefore(d22, op)); + } + + if (op == d40) { + EXPECT_TRUE(order->ExecutesBefore(d30, op)); + EXPECT_TRUE(order->ExecutesBefore(d31, op)); + } else { + EXPECT_FALSE(order->ExecutesBefore(d30, op)); + EXPECT_FALSE(order->ExecutesBefore(d31, op)); + } + + EXPECT_FALSE(order->ExecutesBefore(d40, op)); + } +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc new file mode 100644 index 0000000000..accc406c76 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc @@ -0,0 +1,168 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h" + +#include "external/llvm/include/llvm/IR/BasicBlock.h" +#include "external/llvm/include/llvm/IR/Function.h" +#include "external/llvm/include/llvm/IR/Instructions.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ops.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace gpu { + +void HloToIrBindings::EmitBasePointersForHlos( + tensorflow::gtl::ArraySlice io_hlos, + tensorflow::gtl::ArraySlice non_io_hlos) { + // I/O HLOs are bound to the arguments of the current IR function. I.e., + // + // void IrFunction(io_0, io_1, ..., io_{m-1}, temp_buffer_base) { + llvm::Function* function = ir_builder_->GetInsertBlock()->getParent(); + CHECK_EQ(io_hlos.size() + 1, function->arg_size()); + + // An HLO can have duplicated operands. This data structure remembers which + // operand HLOs are already bound to avoid rebinding the same HLO. + std::set already_bound_for_this_function; + auto arg_iter = function->arg_begin(); + for (const auto* io_hlo : io_hlos) { + if (!already_bound_for_this_function.count(io_hlo)) { + if (!is_nested_ && io_hlo->opcode() == HloOpcode::kGetTupleElement) { + BindHloToIrValue(*io_hlo, EmitGetTupleElement(io_hlo, &*arg_iter)); + } else { + BindHloToIrValue(*io_hlo, &*arg_iter); + } + already_bound_for_this_function.insert(io_hlo); + } + ++arg_iter; + } + + temp_buffer_base_ = &*arg_iter; + temp_buffer_base_->setName("temp_buffer"); + + for (auto* non_io_hlo : non_io_hlos) { + if (already_bound_for_this_function.count(non_io_hlo)) { + continue; + } + already_bound_for_this_function.insert(non_io_hlo); + + if (non_io_hlo->opcode() == HloOpcode::kGetTupleElement) { + if (!is_nested_) { + // Lookup allocation GetTupleElement operand. + const BufferAllocation* allocation = + buffer_assignment_ + ->GetUniqueTopLevelAllocation(LatestNonGteAncestor(non_io_hlo)) + .ConsumeValueOrDie(); + // We are not in a nested context, so check non-thread-local allocation. + CHECK(!allocation->is_thread_local()); + int64 offset = temp_buffer_offsets_->GetOffset(allocation->index()); + CHECK_NE(nullptr, temp_buffer_base_); + // Emit IR for GetTupleElement instruction and bind to emitted value. + llvm::Value* base_ptr = ir_builder_->CreateInBoundsGEP( + temp_buffer_base_, ir_builder_->getInt64(offset)); + BindHloToIrValue(*non_io_hlo, + EmitGetTupleElement(non_io_hlo, base_ptr)); + } + continue; + } + + if (!buffer_assignment_->HasTopLevelAllocation(non_io_hlo)) { + continue; + } + + // A non-IO HLO with a buffer is bound to + // (1) an alloca if it is thread-local, or + // (2) an internal pointer in temp_buffer_base according to its offset. 
+ const BufferAllocation* allocation = + buffer_assignment_->GetUniqueTopLevelAllocation(non_io_hlo) + .ConsumeValueOrDie(); + if (allocation->is_thread_local()) { + llvm::Type* pointee_type = + llvm_ir::ShapeToIrType(non_io_hlo->shape(), ir_builder_); + BindHloToIrValue(*non_io_hlo, ir_builder_->CreateAlloca(pointee_type)); + } else { + int64 offset = temp_buffer_offsets_->GetOffset(allocation->index()); + CHECK_NE(nullptr, temp_buffer_base_); + BindHloToIrValue(*non_io_hlo, + ir_builder_->CreateInBoundsGEP( + temp_buffer_base_, ir_builder_->getInt64(offset))); + } + } +} + +llvm::Value* HloToIrBindings::EmitGetTupleElement(const HloInstruction* gte, + llvm::Value* base_ptr) { + // TODO(b/26344050): tighten the alignment based on the real element type. + if (gte->operand(0)->opcode() != HloOpcode::kGetTupleElement) { + return llvm_ir::EmitGetTupleElement( + gte->shape(), gte->tuple_index(), /*alignment=*/1, + GetTypedIrValue(*gte->operand(0), base_ptr), ir_builder_); + } + return llvm_ir::EmitGetTupleElement( + gte->shape(), gte->tuple_index(), /*alignment=*/1, + EmitGetTupleElement(gte->operand(0), base_ptr), ir_builder_); +} + +llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo, + llvm::Value* ir_value) { + llvm::Type* pointee_type = llvm_ir::ShapeToIrType(hlo.shape(), ir_builder_); + llvm::Type* dest_type = pointee_type->getPointerTo(); + + llvm::Value* typed_ir_value; + if (llvm::isa(ir_value)) { + typed_ir_value = llvm::ConstantExpr::getBitCast( + llvm::cast(ir_value), dest_type); + } else { + typed_ir_value = + ir_builder_->CreateBitCast(ir_value, pointee_type->getPointerTo()); + } + string ir_value_name = llvm_ir::SanitizeIrName(hlo.name()); + ir_value->setName(llvm_ir::AsStringRef(ir_value_name + ".raw")); + typed_ir_value->setName(llvm_ir::AsStringRef(ir_value_name + ".typed")); + return typed_ir_value; +} + +void HloToIrBindings::BindHloToIrValue(const HloInstruction& hlo, + llvm::Value* ir_value) { + VLOG(2) << "Binding " << hlo.ToString(); + InsertOrDie(&base_ptrs_, &hlo, GetTypedIrValue(hlo, ir_value)); +} + +llvm_ir::IrArray HloToIrBindings::GetIrArray(const HloInstruction& hlo) { + llvm_ir::IrArray ir_array(GetBasePointer(hlo), hlo.shape()); + alias_analysis_.AddAliasingInformationToIrArray(hlo, &ir_array); + return ir_array; +} + +void HloToIrBindings::UnbindAllLocalIrValues() { + std::vector hlos_to_unbind; + for (auto& key_value : base_ptrs_) { + if (!llvm::isa( + key_value.second->stripPointerCasts())) { + hlos_to_unbind.push_back(key_value.first); + } + } + for (const HloInstruction* hlo_to_unbind : hlos_to_unbind) { + VLOG(2) << "Unbinding " << hlo_to_unbind->ToString(); + base_ptrs_.erase(hlo_to_unbind); + } +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h new file mode 100644 index 0000000000..1e3b268423 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h @@ -0,0 +1,109 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_TO_IR_BINDINGS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_TO_IR_BINDINGS_H_ + +#include + +#include "external/llvm/include/llvm/IR/IRBuilder.h" +#include "external/llvm/include/llvm/IR/Value.h" +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/temp_buffer_offsets.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace xla { +namespace gpu { + +// This class encapsulates the bindings between HloInstructions and LLVM IR +// values that represent their addresses. +class HloToIrBindings { + public: + HloToIrBindings(const HloModule& module, + const BufferAssignment* buffer_assignment, + const TempBufferOffsets* temp_buffer_offsets, + llvm::IRBuilder<>* ir_builder, bool is_nested) + : buffer_assignment_(buffer_assignment), + temp_buffer_offsets_(temp_buffer_offsets), + is_nested_(is_nested), + ir_builder_(ir_builder), + alias_analysis_(module, *buffer_assignment_, + &ir_builder_->getContext()) {} + + void EmitBasePointersForHlos( + tensorflow::gtl::ArraySlice io_hlos, + tensorflow::gtl::ArraySlice non_io_hlos); + + // Rebinds the given HLO to the LLVM IR value that represent its address. + void BindHloToIrValue(const HloInstruction& hlo, llvm::Value* ir_value); + + // Unbinds all IR values that's defined in an LLVM function, e.g., function + // arguments and stack variables. Global variables will be kept in bindings_. + // + // This method is called after emitting code for each top-level HLO. The local + // IR values are out of scope at that point and should not be used. + void UnbindAllLocalIrValues(); + + // Returns whether `hlo` is bound to an LLVM IR value. + bool BoundToIrValue(const HloInstruction& hlo) const { + return base_ptrs_.count(&hlo); + } + + llvm::Value* GetTempBufferBase() const { return temp_buffer_base_; } + + // A helper method that returns the base pointer of the IrArray for "inst". + llvm::Value* GetBasePointer(const HloInstruction& hlo) const { + auto it = base_ptrs_.find(&hlo); + CHECK(it != base_ptrs_.end()); + return it->second; + } + + // Return the underlying IrArray of the output of the given instruction. + llvm_ir::IrArray GetIrArray(const HloInstruction& hlo); + + private: + // Emits IR to resolve (possibly) recursive GetTupleElement instructions. + llvm::Value* EmitGetTupleElement(const HloInstruction* gte, + llvm::Value* base_ptr); + + // Returns an llvm typed ir representation of 'ir_value' based on 'hlo' shape. + llvm::Value* GetTypedIrValue(const HloInstruction& hlo, + llvm::Value* ir_value); + + const BufferAssignment* buffer_assignment_; + + const TempBufferOffsets* temp_buffer_offsets_; + + const bool is_nested_; + + llvm::IRBuilder<>* ir_builder_; + + // Stores the underlying llvm::IrArray for each HloInstruction. + std::unordered_map base_ptrs_; + + // The address of the memory block that contains all temporary buffers. 
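+  // Non-thread-local HLOs are bound to internal pointers computed as
+  //   temp_buffer_base_ + temp_buffer_offsets_->GetOffset(allocation_index)
+  // in EmitBasePointersForHlos.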
+ llvm::Value* temp_buffer_base_; + + llvm_ir::AliasAnalysis alias_analysis_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_TO_IR_BINDINGS_H_ diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc new file mode 100644 index 0000000000..91fd7ae77a --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -0,0 +1,90 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" + +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" + +namespace xla { +namespace gpu { + +namespace { + +bool IsFusile(const HloInstruction& hlo) { + return (hlo.IsElementwise() && hlo.operand_count() > 0) || + hlo.opcode() == HloOpcode::kBroadcast || + hlo.opcode() == HloOpcode::kConcatenate || + hlo.opcode() == HloOpcode::kDynamicSlice || + hlo.opcode() == HloOpcode::kDynamicUpdateSlice || + hlo.opcode() == HloOpcode::kFusion || + hlo.opcode() == HloOpcode::kGetTupleElement || + hlo.opcode() == HloOpcode::kPad || + hlo.opcode() == HloOpcode::kReduce || + hlo.opcode() == HloOpcode::kReduceWindow || + hlo.opcode() == HloOpcode::kReshape || + hlo.opcode() == HloOpcode::kSlice || + hlo.opcode() == HloOpcode::kTranspose; +} + +} // namespace + +bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, + int64 operand_index) { + HloInstruction* producer = consumer->mutable_operand(operand_index); + + // Do not fuse to-vector reduction into other consumers. They should be + // unfused or the root of a kInput fusion. + if (IsReductionToVector(*producer)) { + return false; + } + + // We can't fuse library calls, so if a user of such an op could become a + // bitcast, leave it unfused. See `xla::InstructionFusion::ShouldFuse` for + // further rationale. + if (producer->CouldBeBitcast() && + ImplementedAsLibraryCall(*producer->operand(0))) { + return false; + } + + // We may need to know original operand layout to emit input fusion, and so + // far, we merely use the layout of an operand of the fusion node, which means + // we must fuse only elementwise operations. This restriction should be lifted + // later if we need to fuse other operations, e.g. transpose, for performance. 
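The layout concern in the comment above can be made concrete with ordinary indexing arithmetic: the linear offset of element (i, j) depends on whether the operand is stored row-major or column-major, so the emitter has to commit to one operand's layout, which is why only elementwise producers are fused here. A minimal sketch:

#include <cstddef>

// Linear offset of element (i, j) in a [rows x cols] array under two layouts.
size_t RowMajorOffset(size_t i, size_t j, size_t cols) { return i * cols + j; }
size_t ColMajorOffset(size_t i, size_t j, size_t rows) { return j * rows + i; }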
+ if ((IsReductionToVector(*consumer) || + (HloOpcode::kFusion == consumer->opcode() && + HloInstruction::FusionKind::kInput == consumer->fusion_kind())) && + !producer->IsElementwise()) { + return false; + } + + return IsFusile(*producer) && IsFusile(*consumer) && + InstructionFusion::ShouldFuse(consumer, operand_index); +} + +HloInstruction::FusionKind GpuInstructionFusion::ChooseKind( + const HloInstruction* producer, const HloInstruction* consumer) { + if (IsReductionToVector(*consumer)) { + return HloInstruction::FusionKind::kInput; + } + if (HloOpcode::kFusion == consumer->opcode()) { + return consumer->fusion_kind(); + } + return InstructionFusion::ChooseKind(producer, consumer); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.h b/tensorflow/compiler/xla/service/gpu/instruction_fusion.h new file mode 100644 index 0000000000..21f3b542a2 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.h @@ -0,0 +1,39 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_INSTRUCTION_FUSION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_INSTRUCTION_FUSION_H_ + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/instruction_fusion.h" + +namespace xla { +namespace gpu { + +class GpuInstructionFusion : public InstructionFusion { + public: + explicit GpuInstructionFusion(bool may_duplicate) + : InstructionFusion(may_duplicate) {} + + bool ShouldFuse(HloInstruction* consumer, int64 operand_index) override; + + HloInstruction::FusionKind ChooseKind( + const HloInstruction* producer, const HloInstruction* consumer) override; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_INSTRUCTION_FUSION_H_ diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc new file mode 100644 index 0000000000..c58af04bad --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -0,0 +1,126 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" + +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace gpu { + +using InstructionFusionTest = HloTestBase; + +TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfDotUnfused) { + HloComputation::Builder builder(TestName()); + auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(S32, {1, 1}), "0")); + auto dot1 = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(S32, {1, 1}), HloOpcode::kDot, param0, param0)); + auto reshape2 = builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(S32, {1, 1, 1}), dot1)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + EXPECT_EQ(reshape2, computation->root_instruction()); + EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); +} + +TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) { + HloComputation::Builder builder(TestName()); + auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(S32, {1, 1}), "0")); + auto dot1 = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(S32, {1, 1}), HloOpcode::kDot, param0, param0)); + auto transpose2 = builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(S32, {1, 1}), dot1, {0, 1})); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + EXPECT_EQ(transpose2, computation->root_instruction()); + EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); +} + +TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfConvolutionUnfused) { + HloComputation::Builder builder(TestName()); + auto input = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1, 1, 1, 3}), "input")); + auto filter = builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {1, 1, 1, 2}), "filter")); + + Window conv_window; + WindowDimension* conv_window_row = conv_window.add_dimensions(); + conv_window_row->set_size(1); + WindowDimension* conv_window_col = conv_window.add_dimensions(); + conv_window_col->set_size(2); + conv_window_col->set_padding_high(1); + + ConvolutionDimensionNumbers conv_dnums; + conv_dnums.set_batch_dimension(0); + conv_dnums.set_feature_dimension(1); + conv_dnums.add_spatial_dimensions(2); + conv_dnums.add_spatial_dimensions(3); + conv_dnums.set_kernel_output_feature_dimension(0); + conv_dnums.set_kernel_input_feature_dimension(1); + conv_dnums.add_kernel_spatial_dimensions(2); + conv_dnums.add_kernel_spatial_dimensions(3); + + auto conv = builder.AddInstruction( + HloInstruction::CreateConvolve(ShapeUtil::MakeShape(F32, {1, 1, 1, 3}), + input, filter, conv_window, conv_dnums)); + auto transpose = builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {3, 1, 1, 1}), conv, {3, 2, 1, 0})); + builder.AddInstruction( + HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {3}), transpose)); + + auto module = MakeUnique(TestName()); + module->AddEntryComputation(builder.Build()); + EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); +} + +TEST_F(InstructionFusionTest, 
GetTupleElementFused) { + HloComputation::Builder builder(TestName()); + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + Shape tuple_shape = ShapeUtil::MakeTupleShape({data_shape, data_shape}); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "param")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, param, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, param, 1)); + builder.AddInstruction( + HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, gte0, gte1)); + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(HloOpcode::kFusion, root->opcode()); + HloInstruction* fused_root = root->fused_expression_root(); + EXPECT_EQ(HloOpcode::kAdd, fused_root->opcode()); + // Check that operands of 'fused_root' are GTE. + EXPECT_EQ(HloOpcode::kGetTupleElement, fused_root->operand(0)->opcode()); + EXPECT_EQ(HloOpcode::kGetTupleElement, fused_root->operand(1)->opcode()); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc new file mode 100644 index 0000000000..0821fb01ab --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -0,0 +1,200 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" + +#include + +#include "external/llvm/include/llvm/IR/Module.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/window_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace xla { +namespace gpu { + +namespace { + +// Return whether the given shape is a matrix with no padding. +bool IsRank2WithNoPadding(const Shape& shape) { + return ShapeUtil::Rank(shape) == 2 && !LayoutUtil::IsPadded(shape); +} + +// In a gemm operation where output = lhs * rhs, check whether the given shapes +// are valid for the operation. +bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, + const Shape& output_shape) { + // The inputs and the output must + // 1) be matrices with no padding and a non-zero number of elements, + // 2) have an allowed element type. 
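As a rough worked example of the conditions listed above (stand-in struct; the real code uses Shape, ShapeUtil and LayoutUtil): an unpadded f32[256,128] dot f32[128,64] producing f32[256,64] qualifies for the gemm path, while an s32 dot or a rank-3 operand does not.

#include <cstdint>

struct ToyShape {        // stand-in for xla::Shape
  int rank;
  bool padded;
  int64_t elements;
  bool f32_or_f64;
};

bool ToyGemmOk(const ToyShape& lhs, const ToyShape& rhs, const ToyShape& out) {
  auto rank2_no_pad = [](const ToyShape& s) { return s.rank == 2 && !s.padded; };
  return out.f32_or_f64 && rank2_no_pad(lhs) && rank2_no_pad(rhs) &&
         rank2_no_pad(out) && lhs.elements != 0 && rhs.elements != 0;
}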
+ bool type_is_allowed = (output_shape.element_type() == F32 || + output_shape.element_type() == F64); + return type_is_allowed && IsRank2WithNoPadding(lhs_shape) && + IsRank2WithNoPadding(rhs_shape) && + IsRank2WithNoPadding(output_shape) && + !ShapeUtil::HasZeroElements(lhs_shape) && + !ShapeUtil::HasZeroElements(rhs_shape); +} +} // namespace + +bool ImplementedAsGemm(const HloInstruction& hlo) { + // For certain types of Dot, we can call pre-canned BLAS gemm. + if (hlo.opcode() == HloOpcode::kDot) { + const Shape& lhs_shape = hlo.operand(0)->shape(); + const Shape& rhs_shape = hlo.operand(1)->shape(); + + // If gemm can accept the operand shapes, use it rather than a custom + // kernel. + if (AreValidGemmShapes(lhs_shape, rhs_shape, hlo.shape())) { + // The size of the reduction dimension should match. The shape inference + // guarantees this invariant, so the check here is for programming + // errors. + CHECK_EQ(lhs_shape.dimensions(1), rhs_shape.dimensions(0)); + return true; + } + } + + if (hlo.opcode() == HloOpcode::kFusion && + hlo.fusion_kind() == HloInstruction::FusionKind::kTransposeDot && + hlo.fused_expression_root()->opcode() == HloOpcode::kDot) { + return true; + } + + return false; +} + +bool ImplementedAsDnnConvolution(const HloInstruction& hlo) { + // Forward convolution. + if (hlo.opcode() == HloOpcode::kConvolution) { + const ConvolutionDimensionNumbers& dnums = + hlo.convolution_dimension_numbers(); + // Only 2D convolutions are implemented. + // TODO(b/32873825): add support for 3D convolutions using CuDNN. + if (dnums.spatial_dimensions_size() != 2) { + return false; + } + // CuDNN does not accept zero-element arguments + if (ShapeUtil::HasZeroElements(hlo.operand(0)->shape()) || + ShapeUtil::HasZeroElements(hlo.operand(1)->shape())) { + return false; + } + + return true; + } + + // Backward convolution. 
+ if (hlo.opcode() == HloOpcode::kFusion && + (hlo.fusion_kind() == HloInstruction::FusionKind::kConvBackwardFilter || + hlo.fusion_kind() == HloInstruction::FusionKind::kConvBackwardInput)) { + return true; + } + + return false; +} + +bool ImplementedAsLibraryCall(const HloInstruction& hlo) { + return ImplementedAsGemm(hlo) || ImplementedAsDnnConvolution(hlo); +} + +bool IsReductionToVector(const HloInstruction& reduce) { + if (HloOpcode::kReduce != reduce.opcode()) { + return false; + } + const HloInstruction* input = reduce.operand(0); + return ShapeUtil::Rank(input->shape()) > 1 && + ShapeUtil::Rank(reduce.shape()) == 1; +} + +// This emits a device-side call to +// "i32 vprintf(i8* fmt, arguments_type* arguments)" in the driver; see +// http://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability/index.html#system-calls +llvm::Value* EmitPrintf(tensorflow::StringPiece fmt, + tensorflow::gtl::ArraySlice arguments, + llvm::IRBuilder<>* builder) { + std::vector argument_types; + for (auto argument : arguments) { + argument_types.push_back(argument->getType()); + } + auto* arguments_type = llvm::StructType::create(argument_types); + llvm::Value* arguments_ptr = builder->CreateAlloca(arguments_type); + for (size_t i = 0; i < arguments.size(); ++i) { + builder->CreateStore( + arguments[i], + builder->CreateGEP(arguments_ptr, + {builder->getInt64(0), builder->getInt32(i)})); + } + return builder->CreateCall( + builder->GetInsertBlock()->getParent()->getParent()->getOrInsertFunction( + "vprintf", + llvm::FunctionType::get(builder->getInt32Ty(), + {builder->getInt8Ty()->getPointerTo(), + arguments_type->getPointerTo()}, + /*isVarArg=*/false)), + {builder->CreateGlobalStringPtr(llvm_ir::AsStringRef(fmt)), + arguments_ptr}); +} + +llvm::Value* EmitShuffleDown(llvm::Value* value, llvm::Value* offset, + llvm::IRBuilder<>* builder) { + int bit_width = value->getType()->getPrimitiveSizeInBits(); + + // Special case for efficiency + if (value->getType()->isFloatTy() && bit_width == 32) { + return llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::nvvm_shfl_down_f32, + {value, offset, builder->getInt32(kWarpSize - 1)}, {}, builder); + } + + // We must split values wider than 32 bits as the "shfl" instruction operates + // on 32-bit values. + int num_segments = CeilOfRatio(bit_width, 32); + llvm::Value* x = builder->CreateBitCast( + builder->CreateZExt( + builder->CreateBitCast(value, builder->getIntNTy(bit_width)), + builder->getIntNTy(32 * num_segments)), + llvm::VectorType::get(builder->getInt32Ty(), num_segments)); + for (int i = 0; i < num_segments; ++i) { + x = builder->CreateInsertElement( + x, + llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_shfl_down_i32, + {builder->CreateExtractElement(x, i), + offset, builder->getInt32(kWarpSize - 1)}, + {}, builder), + i); + } + return builder->CreateBitCast( + builder->CreateTrunc( + builder->CreateBitCast(x, builder->getIntNTy(32 * num_segments)), + builder->getIntNTy(bit_width)), + value->getType()); +} + +const HloInstruction* LatestNonGteAncestor(const HloInstruction* hlo) { + while (hlo->opcode() == HloOpcode::kGetTupleElement) { + hlo = hlo->operand(0); + } + return hlo; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h new file mode 100644 index 0000000000..4d3e9b10b2 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h @@ -0,0 +1,72 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMISSION_UTILS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMISSION_UTILS_H_ + +#include + +#include "external/llvm/include/llvm/IR/IRBuilder.h" +#include "external/llvm/include/llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" + +namespace xla { +namespace gpu { + +const int64 kWarpSize = 32; + +// Precondition: "hlo" is an operand of a Dot instruction. +// +// Returns whether "hlo" is foldable to its user. +bool IsOperandFoldableToDot(const HloInstruction& hlo); + +// Returns true if GpuCompiler can fold any operands of "dot" into "dot" for +// better performance. +bool CanFoldOperandsIntoDot(const HloInstruction& dot); + +// Returns true if `hlo` will be implemented as a call to BLAS gemm. +bool ImplementedAsGemm(const HloInstruction& hlo); + +// Returns true if `hlo` will be implemented as a call to cuDNN convolution. +bool ImplementedAsDnnConvolution(const HloInstruction& hlo); + +// Returns true if `hlo` will be implemented as a library call, e.g. cuBLAS gemm +// or cuDNN convolution. +bool ImplementedAsLibraryCall(const HloInstruction& hlo); + +bool IsReductionToVector(const HloInstruction& reduce); + +// Emits call to "vprintf" with given format and arguments. +llvm::Value* EmitPrintf(tensorflow::StringPiece fmt, + tensorflow::gtl::ArraySlice arguments, + llvm::IRBuilder<>* builder); + +// Emits code to shuffle data between threads of a warp. This has the same +// semantics as the PTX "shfl.down" instruction [0] but works for values of any +// size. The last operand of the emitted "shfl" is `kWarpSize - 1`. +// +// [0] +// http://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-shfl +llvm::Value* EmitShuffleDown(llvm::Value* value, llvm::Value* offset, + llvm::IRBuilder<>* builder); + +// Resolves GetTupleElement instruction operands starting with 'hlo'. +// Returns the first ancestor instruction which is not a GetTupleElement. +const HloInstruction* LatestNonGteAncestor(const HloInstruction* hlo); + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMISSION_UTILS_H_ diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc new file mode 100644 index 0000000000..ee1027b8cc --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -0,0 +1,645 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h" + +#include +#include + +#include "tensorflow/core/platform/logging.h" +// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" +#include "external/llvm/include/llvm/IR/BasicBlock.h" +#include "external/llvm/include/llvm/IR/Constants.h" +#include "external/llvm/include/llvm/IR/Instructions.h" +#include "external/llvm/include/llvm/IR/Module.h" +#include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" +#include "tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h" +#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ops.h" +#include "tensorflow/compiler/xla/service/name_uniquer.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/window_util.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace xla { + +using llvm_ir::SetToFirstInsertPoint; + +namespace gpu { + +IrEmitter::IrEmitter(const HloModuleConfig& hlo_module_config, + IrEmitterContext* ir_emitter_context, bool is_nested) + : ir_emitter_context_(ir_emitter_context), + ir_builder_(ir_emitter_context->llvm_module()->getContext()), + bindings_(ir_emitter_context->hlo_module(), + &ir_emitter_context->buffer_assignment(), + &ir_emitter_context->temp_buffer_offsets(), &ir_builder_, + is_nested), + hlo_module_config_(hlo_module_config) { + llvm::FastMathFlags fast_math_flags; + llvm_ir::SetFastMathFlags(&fast_math_flags); + ir_builder_.setFastMathFlags(fast_math_flags); +} + +Status IrEmitter::DefaultAction(HloInstruction* hlo) { + ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator; + for (const HloInstruction* operand : hlo->operands()) { + operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) { + return GetIrArray(*operand).EmitReadArrayElement(index, &ir_builder_); + }; + } + return EmitTargetElementLoop( + *hlo, GpuElementalIrEmitter(hlo_module_config_, + ir_emitter_context_->llvm_module(), + &ir_builder_, GetNestedComputer()) + .MakeElementGenerator(hlo, operand_to_generator)); +} + +Status IrEmitter::HandleConstant(HloInstruction* constant, + const Literal& literal) { + llvm::Constant* initializer = + llvm_ir::ConvertLiteralToIrConstant(literal, &ir_builder_); + llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable( + *ir_emitter_context_->llvm_module(), initializer->getType(), + /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, initializer, + /*Name=*/""); + VLOG(2) << "HandleConstant: " << constant->ToString() << std::endl + << " emitted_value: " << llvm_ir::DumpToString(*global_for_const) + << std::endl + << " its type: " + << llvm_ir::DumpToString(*global_for_const->getType()); + bindings_.BindHloToIrValue(*constant, global_for_const); + 
return Status::OK(); +} + +Status IrEmitter::HandleBitcast(HloInstruction* bitcast) { + VLOG(2) << "HandleBitcast: " << bitcast->ToString(); + const HloInstruction* operand = bitcast->operand(0); + // Bitcast is a no-op, but we still want to bind it to an llvm::Value + // sometimes, e.g., when it's operand is a constant or a bitcast of a + // constant. + if (bindings_.BoundToIrValue(*operand)) { + bindings_.BindHloToIrValue(*bitcast, bindings_.GetBasePointer(*operand)); + } + return Status::OK(); +} + +Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element, + HloInstruction* operand) { + CHECK(bindings_.BoundToIrValue(*operand)); + bindings_.BindHloToIrValue( + *get_tuple_element, + llvm_ir::EmitGetTupleElement( + get_tuple_element->shape(), get_tuple_element->tuple_index(), + // TODO(b/26344050): tighten the alignment here + // based on the real element type. + /*alignment=*/1, GetBasePointer(*operand), &ir_builder_)); + return Status::OK(); +} + +Status IrEmitter::HandleSort(HloInstruction* sort, + HloInstruction* operand_instruction) { + // TODO(b/26783907): Implement sort on GPU. + return Unimplemented("sort"); +} + +Status IrEmitter::HandleSend(HloInstruction* send) { + return Unimplemented("Send is not implemented on GPU"); +} + +Status IrEmitter::HandleRecv(HloInstruction* recv) { + return Unimplemented("Recv is not implemented on GPU"); +} + +Status IrEmitter::HandleTuple( + HloInstruction* tuple, + tensorflow::gtl::ArraySlice operands) { + std::vector base_ptrs; + for (const HloInstruction* operand : operands) { + base_ptrs.push_back(GetBasePointer(*operand)); + } + llvm_ir::EmitTuple(GetIrArray(*tuple), base_ptrs, &ir_builder_); + return Status::OK(); +} + +Status IrEmitter::EmitCallToNestedComputation( + const HloComputation& nested_computation, + tensorflow::gtl::ArraySlice operands, llvm::Value* output) { + TF_RET_CHECK(nested_computation.num_parameters() > 0); + llvm::Function*& emitted_function = + computation_to_ir_function_[&nested_computation]; + if (emitted_function == nullptr) { + IrEmitterNested ir_emitter_nested(hlo_module_config_, nested_computation, + ir_emitter_context_); + TF_RETURN_IF_ERROR( + nested_computation.root_instruction()->Accept(&ir_emitter_nested)); + emitted_function = ir_emitter_nested.GetEmittedFunction(); + } + + std::vector arguments(operands.begin(), operands.end()); + arguments.push_back(output); + arguments.push_back(bindings_.GetTempBufferBase()); + ir_builder_.CreateCall(emitted_function, arguments); + + return Status::OK(); +} + +bool IrEmitter::MaybeEmitSpecialAtomicOperation( + const HloComputation& computation, llvm::Value* output_address, + llvm::Value* source_address) { + CHECK_EQ(2, computation.num_parameters()); + + if (computation.instruction_count() != 3) { + // We special-case only computations with one computing instruction for now. + // Such computation has exactly three instructions given it has two + // parameters. + return false; + } + + HloOpcode root_opcode = computation.root_instruction()->opcode(); + PrimitiveType element_type = + computation.root_instruction()->shape().element_type(); + llvm::Value* source = ir_builder_.CreateLoad(source_address, "source"); + if (root_opcode == HloOpcode::kAdd) { + // NVPTX supports atomicAdd on F32 and integer types. 
+ if (element_type == F32) { + // F32 + F32 + llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_atomic_load_add_f32, + {output_address, source}, + {output_address->getType()}, &ir_builder_); + return true; + } + if (primitive_util::IsIntegralType(element_type)) { + // integral + integral + ir_builder_.CreateAtomicRMW(llvm::AtomicRMWInst::Add, output_address, + source, + llvm::AtomicOrdering::SequentiallyConsistent); + return true; + } + } + + // NVPTX supports atomicMax and atomicMin on only integer types. + if (root_opcode == HloOpcode::kMaximum && + primitive_util::IsIntegralType(element_type)) { + // min(integral, integral) + ir_builder_.CreateAtomicRMW(llvm::AtomicRMWInst::Max, output_address, + source, + llvm::AtomicOrdering::SequentiallyConsistent); + return true; + } + + if (root_opcode == HloOpcode::kMinimum && + primitive_util::IsIntegralType(element_type)) { + // max(integral, integral) + ir_builder_.CreateAtomicRMW(llvm::AtomicRMWInst::Min, output_address, + source, + llvm::AtomicOrdering::SequentiallyConsistent); + return true; + } + + return false; +} + +Status IrEmitter::EmitAtomicOperationForNestedComputation( + const HloComputation& computation, llvm::Value* output_address, + llvm::Value* source_address) { + if (computation.num_parameters() != 2) { + // TODO(b/30258929): We only accept binary computations so far. + return Unimplemented( + "We only support atomic functions with exactly two parameters, but " + "computation %s has %lld.", + computation.name().c_str(), computation.num_parameters()); + } + + if (MaybeEmitSpecialAtomicOperation(computation, output_address, + source_address)) { + return Status::OK(); + } + + // Other binary computations can be made atomic as following (labels are basic + // block names used in the IR emitting code later). + // + // atomic_op_loop_preheader: + // ... + // source = *source_address; + // old_output = *output_address; + // do { + // atomic_op_loop_body_entry: + // new_output = computation(old_output, source); + // (old_output, success) = + // atomicCAS(output_address, old_output, new_output); + // } while (!success); + // + // atomic_op_loop_exit: + // ... + // + // TODO(jingyue): Consider encapsulate the logic of emitting control flow to + // something similar to llvm_ir::ForLoop. + // + // Emit preparation code to the preheader. + llvm::BasicBlock* loop_preheader_bb = ir_builder_.GetInsertBlock(); + llvm::Type* element_ir_type = + output_address->getType()->getPointerElementType(); + // old_output = *output_address; + llvm::Value* old_output_location = ir_builder_.CreateAlloca( + element_ir_type, /*ArraySize=*/nullptr, "old_output_location"); + ir_builder_.CreateStore(ir_builder_.CreateLoad(output_address, "old_output"), + old_output_location); + llvm::BasicBlock* loop_exit_bb = loop_preheader_bb->splitBasicBlock( + ir_builder_.GetInsertPoint(), "atomic_op_loop_exit"); + + // Emit the body of the loop that repeatedly invokes atomicCAS. + llvm::BasicBlock* loop_body_bb = + llvm::BasicBlock::Create(ir_builder_.getContext(), "atomic_op_loop_body", + ir_builder_.GetInsertBlock()->getParent()); + ir_builder_.SetInsertPoint(loop_body_bb); + // Change preheader's successor from loop_exit_bb to loop_body_bb. 
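The retry loop described in the comment block above has a familiar host-side analogue. The sketch below uses std::atomic and a memcpy bit-cast in place of the emitted cmpxchg and bitcasts; the functor stands in for the nested computation.

#include <atomic>
#include <cstdint>
#include <cstring>

template <typename BinaryOp>
void AtomicApply(std::atomic<uint32_t>* output, float source, BinaryOp op) {
  uint32_t old_bits = output->load();
  for (;;) {
    float old_value;
    std::memcpy(&old_value, &old_bits, sizeof(float));   // bitcast to float
    const float new_value = op(old_value, source);       // the "nested computation"
    uint32_t new_bits;
    std::memcpy(&new_bits, &new_value, sizeof(float));   // bitcast back to i32
    // On failure, compare_exchange_weak reloads the current contents into
    // old_bits, so the next iteration retries against the fresh value.
    if (output->compare_exchange_weak(old_bits, new_bits)) break;
  }
}

For example, a floating-point maximum, which has no native NVPTX atomic, would take this path: AtomicApply(&acc, x, [](float a, float b) { return a < b ? b : a; });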
+ loop_preheader_bb->getTerminator()->setSuccessor(0, loop_body_bb); + // new_output = computation(old_output, source); + llvm::Value* new_output_location = ir_builder_.CreateAlloca( + element_ir_type, /*ArraySize=*/nullptr, "new_output_location"); + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + computation, {old_output_location, source_address}, new_output_location)); + + // (old_output, success) = atomicCAS(output_address, old_output, new_output); + llvm::Type* element_int_ir_type = + ir_builder_.getIntNTy(element_ir_type->getScalarSizeInBits()); + // cmpxchg accetps integer only, so we bitcast the operands (old_output and + // new_output) to integers of the same bit width, and bitcast the result + // back to the original element type. + llvm::Value* old_output = + ir_builder_.CreateLoad(old_output_location, "old_output"); + llvm::Value* new_output = + ir_builder_.CreateLoad(new_output_location, "new_output"); + llvm::Value* ret_value = ir_builder_.CreateAtomicCmpXchg( + ir_builder_.CreateBitCast(output_address, + element_int_ir_type->getPointerTo()), + ir_builder_.CreateBitCast(old_output, element_int_ir_type), + ir_builder_.CreateBitCast(new_output, element_int_ir_type), + llvm::AtomicOrdering::SequentiallyConsistent, + llvm::AtomicOrdering::SequentiallyConsistent); + // cmpxchg returns a pair. The first element is the original value at + // output_address and the second element is whether the swap is successful. + ir_builder_.CreateStore( + ir_builder_.CreateBitCast( + ir_builder_.CreateExtractValue(ret_value, 0, "old_output"), + element_ir_type), + old_output_location); + ir_builder_.CreateCondBr( + ir_builder_.CreateExtractValue(ret_value, 1, "success"), loop_exit_bb, + loop_body_bb); + + // Restore the insertion point to the exit basic block so that the caller of + // this method can continue emitting code to the right place. + SetToFirstInsertPoint(loop_exit_bb, &ir_builder_); + return Status::OK(); +} + +Status IrEmitter::HandleSelect(HloInstruction* select, HloInstruction* pred, + HloInstruction* on_true, + HloInstruction* on_false) { + TF_RET_CHECK(pred->shape().element_type() == PRED); + + if (ShapeUtil::IsTuple(select->shape())) { + llvm_ir::EmitTupleSelect(GetIrArray(*select), GetIrArray(*pred), + GetBasePointer(*on_true), + GetBasePointer(*on_false), &ir_builder_); + return Status::OK(); + } + + // We must not call the subclass `DefaultAction` method, lest its + // `HandleSelect` call `IrEmitter::HandleSelect` and its `DefaultAction` + // assume no handler has already been called. + return IrEmitter::DefaultAction(select); +} + +Status IrEmitter::HandleDot(HloInstruction* dot, + HloInstruction* lhs_instruction, + HloInstruction* rhs_instruction) { + const llvm_ir::IrArray& target_array = GetIrArray(*dot); + const llvm_ir::IrArray& lhs_array = GetIrArray(*lhs_instruction); + const llvm_ir::IrArray& rhs_array = GetIrArray(*rhs_instruction); + + const Shape& lhs_shape = lhs_instruction->shape(); + const Shape& rhs_shape = rhs_instruction->shape(); + + if (ShapeUtil::IsScalar(lhs_shape) && ShapeUtil::IsScalar(rhs_shape)) { + // If the operands are scalar, don't emit any loops. 
+ llvm::Value* lhs_value = + lhs_array.EmitReadArrayElement(/*index=*/{}, &ir_builder_); + llvm::Value* rhs_value = + rhs_array.EmitReadArrayElement(/*index=*/{}, &ir_builder_); + llvm::Value* result = ir_builder_.CreateFMul(lhs_value, rhs_value); + target_array.EmitWriteArrayElement(/*index=*/{}, result, &ir_builder_); + return Status::OK(); + } + + // "Scalar dot non-scalar" or "non-scalar dot scalar" is invalid. See + // the semantics of Dot in the XLA documentation for details. + TF_RET_CHECK(!ShapeUtil::IsScalar(lhs_shape) && + !ShapeUtil::IsScalar(rhs_shape)); + + // Reduce along the last dimension of the LHS and the second-to-last dimension + // of the RHS. Vectors are a special case where the reduction dimension is 0 + // for both LHS and RHS. This results in a vector dot product producing a + // scalar. + const int64 lhs_reduction_dimension = + ShapeUtil::GetDimensionNumber(lhs_shape, -1); + const int64 rhs_reduction_dimension = + ShapeUtil::Rank(rhs_shape) >= 2 + ? ShapeUtil::GetDimensionNumber(rhs_shape, -2) + : 0; + + // Verify the reduction dimension in the two operands are the same size. + TF_RET_CHECK(lhs_shape.dimensions(lhs_reduction_dimension) == + rhs_shape.dimensions(rhs_reduction_dimension)); + + // Create loop nests which loop through the LHS operand dimensions and the RHS + // operand dimensions. The reduction dimension of the LHS and RHS are handled + // in a separate innermost loop which performs the sum of products. + llvm_ir::ForLoopNest loop_nest(&ir_builder_); + llvm_ir::IrArray::Index lhs_index = EmitOperandArrayLoopNest( + lhs_array, lhs_reduction_dimension, "lhs", &loop_nest); + llvm_ir::IrArray::Index rhs_index = EmitOperandArrayLoopNest( + rhs_array, rhs_reduction_dimension, "rhs", &loop_nest); + + // Create the reduction loop which does the sum of products reduction. + std::unique_ptr reduction_loop = loop_nest.AddLoop( + /*start_index=*/0, + /*end_index=*/lhs_shape.dimensions(lhs_reduction_dimension), + /*suffix=*/"reduction"); + + // The final entry in the rhs and lhs indexes is the indvar of the reduction + // loop. + lhs_index[lhs_reduction_dimension] = reduction_loop->GetIndVarValue(); + rhs_index[rhs_reduction_dimension] = reduction_loop->GetIndVarValue(); + + // For computing the sum of products we alloca a single location to store the + // dot product result as we accumulate it within the reduction loop. After the + // reduction loop we load the result and store into the output array. + llvm::Type* accum_type = target_array.GetElementLlvmType(); + llvm::Value* accum_address = llvm_ir::EmitAllocaAtFunctionEntry( + accum_type, // The pointee type of the alloca instruction. + "accum_address", // The name of the alloca instuction. + &ir_builder_); + + // Initialize the accumulator in the preheader to zero. + new llvm::StoreInst( + llvm::ConstantFP::get(accum_type, 0.0), // The value stored. + accum_address, // The address. + reduction_loop->GetPreheaderBasicBlock() + ->getTerminator()); // The instruction this store is inserted before. 
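A scalar picture of the loop nest being emitted here, assuming row-major rank-2 operands: one loop per non-reduction dimension of each operand, plus an inner reduction loop that accumulates into a single location and stores the result once the loop exits.

#include <vector>

std::vector<float> NaiveDot(const std::vector<float>& lhs,  // [m x k], row-major
                            const std::vector<float>& rhs,  // [k x n], row-major
                            int m, int k, int n) {
  std::vector<float> out(m * n, 0.0f);
  for (int i = 0; i < m; ++i) {      // loop over the LHS non-reduction dimension
    for (int j = 0; j < n; ++j) {    // loop over the RHS non-reduction dimension
      float accum = 0.0f;            // plays the role of accum_address
      for (int r = 0; r < k; ++r) {  // the reduction loop
        accum += lhs[i * k + r] * rhs[r * n + j];
      }
      out[i * n + j] = accum;        // store once the reduction loop exits
    }
  }
  return out;
}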
+ + // Emit the body of the reduction loop: + // accum = *accum_address + // updated_accum = accum + lhs_element * rhs_element + // *accum_address = updated_accum + TF_RET_CHECK(!reduction_loop->GetBodyBasicBlock()->empty()); + ir_builder_.SetInsertPoint( + &*reduction_loop->GetBodyBasicBlock()->getFirstInsertionPt()); + llvm::Value* lhs_element = + lhs_array.EmitReadArrayElement(lhs_index, &ir_builder_); + llvm::Value* rhs_element = + rhs_array.EmitReadArrayElement(rhs_index, &ir_builder_); + llvm::Value* product = ir_builder_.CreateFMul(lhs_element, rhs_element); + llvm::Value* accum = ir_builder_.CreateLoad(accum_address); + llvm::Value* updated_accum = ir_builder_.CreateFAdd(accum, product); + ir_builder_.CreateStore(updated_accum, accum_address); + + // After the reduction loop exits, store the accumulator into the target + // address. The index into the target address is the concatenation of the rhs + // and lhs indexes with the reduction dimensions removed. The terms from the + // rhs index are the lower dimensions in the index so we add them first. + llvm_ir::IrArray::Index target_index; + for (int dimension = 0; dimension < lhs_index.size(); ++dimension) { + if (dimension != lhs_reduction_dimension) { + target_index.push_back(lhs_index[dimension]); + } + } + for (int dimension = 0; dimension < rhs_index.size(); ++dimension) { + if (dimension != rhs_reduction_dimension) { + target_index.push_back(rhs_index[dimension]); + } + } + SetToFirstInsertPoint(reduction_loop->GetExitBasicBlock(), &ir_builder_); + target_array.EmitWriteArrayElement( + target_index, + ir_builder_.CreateLoad( + accum_address), // The value written to the target array. + &ir_builder_); + + // Set the IR builder insert point to the exit basic block of the outer most + // loop. This ensures later instructions are inserted after this loop nest. + ir_builder_.SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock()); + + return Status::OK(); +} + +Status IrEmitter::HandleConvolution(HloInstruction* convolution, + HloInstruction* lhs_instruction, + HloInstruction* rhs_instruction, + const Window& window) { + if (ShapeUtil::HasZeroElements(convolution->shape())) { + // Emit no code for an empty output. + return Status::OK(); + } + // TODO(b/31409998): Support convolution with dilation. + return Unimplemented( + "Hit a case for convolution that is not implemented on GPU."); +} + +Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) { + // TODO(b/33011107): Support cross replica sum on GPU. + return Unimplemented( + "Cross replica sum not implemented on GPU. See b/33011107."); +} + +Status IrEmitter::HandleParameter(HloInstruction* parameter) { + return Status::OK(); +} + +Status IrEmitter::HandleReduce(HloInstruction* reduce, HloInstruction* arg, + HloInstruction* init_value, + tensorflow::gtl::ArraySlice dimensions, + HloComputation* function) { + return EmitTargetElementLoop( + *reduce, + [=](const llvm_ir::IrArray::Index& index) -> StatusOr { + // Initialize an accumulator with init_value. + llvm::AllocaInst* accumulator_addr = + ir_builder_.CreateAlloca(llvm_ir::PrimitiveTypeToIrType( + reduce->shape().element_type(), &ir_builder_)); + ir_builder_.CreateStore( + ir_builder_.CreateLoad(GetBasePointer(*init_value)), + accumulator_addr); + + // The enclosing loops go over all the target elements. Now we have to + // compute the actual target element. For this, we build a new loop nest + // to iterate over all the reduction dimensions in the argument. 
+ // AddLoopsForShapeOnDimensions will return an Index where induction + // Value*s are placed for each dimension in dimensions, and all the rest + // are nullptrs. + llvm_ir::ForLoopNest loops(&ir_builder_); + const llvm_ir::IrArray::Index reduced_dims_index = + loops.AddLoopsForShapeOnDimensions(arg->shape(), dimensions, + "reduction_dim"); + + SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_); + + // Build a full index for the input argument, using reduced_dims_index + // as the base. In reduced_dims_index only the reduction dimensions are + // filled in. We fill in the rest of the dimensions with induction + // Value*s taken from 'index' which iterates over the target array. + // See the high-level description in the XLA documentation for details. + llvm_ir::IrArray::Index input_index = reduced_dims_index; + llvm_ir::IrArray::Index::const_iterator it = index.begin(); + + for (int64 i = 0; i < input_index.size(); ++i) { + if (input_index[i] == nullptr) { + input_index[i] = *it++; + } + } + CHECK(index.end() == it); + + // Apply the reduction function to the loaded value. + llvm::Value* input_address = + GetIrArray(*arg).EmitArrayElementAddress(input_index, &ir_builder_); + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + *function, {accumulator_addr, input_address}, accumulator_addr)); + + SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_); + return ir_builder_.CreateLoad(accumulator_addr); + }); +} + +Status IrEmitter::HandleFusion(HloInstruction* fusion) { + // kFusion for library calls should be handled by + // IrEmitterUnnested::HandleFusion. + CHECK(HloInstruction::FusionKind::kLoop == fusion->fusion_kind()); + + std::vector parameter_arrays; + for (HloInstruction* operand : fusion->operands()) { + parameter_arrays.push_back(GetIrArray(*operand)); + } + GpuElementalIrEmitter elemental_emitter(hlo_module_config_, + ir_emitter_context_->llvm_module(), + &ir_builder_, GetNestedComputer()); + FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter); + TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter)); + + return EmitTargetElementLoop(*fusion, fused_emitter.GetRootGenerator()); +} + +Status IrEmitter::HandleCall( + HloInstruction* call, tensorflow::gtl::ArraySlice operands, + HloComputation* computation) { + std::vector operand_addresses; + for (HloInstruction* operand : operands) { + operand_addresses.push_back(GetBasePointer(*operand)); + } + return EmitCallToNestedComputation(*computation, operand_addresses, + GetBasePointer(*call)); +} + +Status IrEmitter::HandleCustomCall( + HloInstruction* custom_call, + tensorflow::gtl::ArraySlice operands, + tensorflow::StringPiece custom_call_target) { + return Unimplemented("custom-call"); +} + +Status IrEmitter::HandleInfeed(HloInstruction* infeed) { + return Unimplemented("Infeed is not supported on GPU (b/30467474)"); +} + +Status IrEmitter::HandleRng(HloInstruction* random, + RandomDistribution /*distribution*/) { + ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator; + for (const HloInstruction* operand : random->operands()) { + operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) { + return GetIrArray(*operand).EmitReadArrayElement(index, &ir_builder_); + }; + } + // Emits a single-threaded loop because the loop body generated by the element + // generator for Rng can't be parallelized (b/32333178). 
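The index bookkeeping in HandleReduce above can be pictured with ordinary arrays. The sketch below reduces dimensions {0, 2} of a row-major [4 x 5 x 6] input to a rank-1 [5] output with an add-reduction, filling the one non-reduced slot of the input index from the output index.

#include <vector>

std::vector<float> ReduceDims0And2(const std::vector<float>& input) {
  const int d0 = 4, d1 = 5, d2 = 6;        // input shape [4 x 5 x 6], row-major
  std::vector<float> output(d1, 0.0f);     // init_value of the add-reduction
  for (int i1 = 0; i1 < d1; ++i1) {        // loop over the target elements
    for (int i0 = 0; i0 < d0; ++i0) {      // loops over the reduction
      for (int i2 = 0; i2 < d2; ++i2) {    // dimensions {0, 2}
        // Full input index: reduction indices i0, i2 plus the target index i1
        // filled into the remaining slot, as in the index-merging loop above.
        output[i1] += input[(i0 * d1 + i1) * d2 + i2];
      }
    }
  }
  return output;
}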
+ return llvm_ir::LoopEmitter( + GpuElementalIrEmitter(hlo_module_config_, + ir_emitter_context_->llvm_module(), + &ir_builder_, GetNestedComputer()) + .MakeElementGenerator(random, operand_to_generator), + GetIrArray(*random), &ir_builder_) + .EmitLoop(); +} + +llvm_ir::IrArray::Index IrEmitter::EmitOperandArrayLoopNest( + const llvm_ir::IrArray& operand_array, int64 reduction_dimension, + tensorflow::StringPiece name_suffix, llvm_ir::ForLoopNest* loop_nest) { + // Prepares the dimension list we will use to emit the loop nest. Outermost + // loops are added first. Add loops in major-to-minor order, and skip the + // reduction dimension. + std::vector dimensions; + const Shape& shape = operand_array.GetShape(); + for (int i = shape.layout().minor_to_major_size() - 1; i >= 0; --i) { + int64 dimension = shape.layout().minor_to_major(i); + if (dimension != reduction_dimension) { + dimensions.push_back(dimension); + } + } + + // Create loop nest with one for-loop for each dimension of the + // output. + llvm_ir::IrArray::Index index = + loop_nest->AddLoopsForShapeOnDimensions(shape, dimensions, name_suffix); + // Verify every dimension except the reduction dimension was set in the index. + for (int dimension = 0; dimension < index.size(); ++dimension) { + if (dimension == reduction_dimension) { + DCHECK_EQ(nullptr, index[dimension]); + } else { + DCHECK_NE(nullptr, index[dimension]); + } + } + return index; +} + +StatusOr IrEmitter::ComputeNestedElement( + const HloComputation& computation, + tensorflow::gtl::ArraySlice parameter_elements) { + llvm::Value* return_buffer = llvm_ir::EmitAllocaAtFunctionEntry( + llvm_ir::PrimitiveTypeToIrType( + computation.root_instruction()->shape().element_type(), &ir_builder_), + "return_buffer", &ir_builder_); + std::vector parameter_buffers; + for (llvm::Value* parameter_element : parameter_elements) { + parameter_buffers.push_back(llvm_ir::EmitAllocaAtFunctionEntry( + parameter_element->getType(), "parameter_buffer", &ir_builder_)); + ir_builder_.CreateStore(parameter_element, parameter_buffers.back()); + } + TF_RETURN_IF_ERROR(EmitCallToNestedComputation(computation, parameter_buffers, + return_buffer)); + return ir_builder_.CreateLoad(return_buffer); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h new file mode 100644 index 0000000000..0764c61ede --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -0,0 +1,405 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// An XLA HLO graph may contain multiple computations. These computations +// fall into two types, nested and unnested. We translate each nested +// computation (e.g. the computation operand of a Map operator) to a device +// function. 
For each unnested computation composed of top-level +// HloInstructions, we generate a CUDA kernel for each HloInstruction. +// +// This file declares classes that translate an XLA HLO graph to LLVM IR for +// GPUs. IrEmitterNested emits LLVM IR for nested computations, and +// IrEmitterUnnested for unnested computations. The logic of emitting LLVM IR +// for each individual HloInstruction is largely the same between these two +// classes. Therefore, we implement the common logic in the Handle* functions in +// the superclass IrEmitter. + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_H_ + +#include +#include +#include +#include +#include + +#include "external/llvm/include/llvm/IR/Function.h" +#include "external/llvm/include/llvm/IR/IRBuilder.h" +#include "external/llvm/include/llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" +#include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/thunk.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" +#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace gpu { + +// This class is the top-level API for the XLA HLO --> LLVM IR compiler. +// It implements the DfsHloVisitor interface and emits an LLVM IR program that +// implements the input HLO graph. +// +// Note: if `T` is a subclass of `IrEmitter` and a handler is not overridden in +// either `IrEmitter` or `T`, the handler in `DfsHloVisitorWithDefault` +// calls `T::DefaultAction`. +class IrEmitter : public DfsHloVisitorWithDefault { + public: + IrEmitter(const IrEmitter&) = delete; + IrEmitter& operator=(const IrEmitter&) = delete; + + // The following methods implement the DfsHloVisitorWithDefault interface. 
+ Status DefaultAction(HloInstruction* hlo) override; + Status HandleConstant(HloInstruction* constant, + const Literal& literal) override; + Status HandleBitcast(HloInstruction* bitcast) override; + Status HandleGetTupleElement(HloInstruction* get_tuple_element, + HloInstruction* operand) override; + Status HandleDot(HloInstruction* dot, HloInstruction* lhs, + HloInstruction* rhs) override; + Status HandleConvolution(HloInstruction* convolution, HloInstruction* lhs, + HloInstruction* rhs, const Window& window) override; + Status HandleCrossReplicaSum(HloInstruction* crs) override; + Status HandleInfeed(HloInstruction* infeed) override; + Status HandleSort(HloInstruction* sort, HloInstruction* operand) override; + Status HandleSend(HloInstruction* send) override; + Status HandleRecv(HloInstruction* recv) override; + Status HandleParameter(HloInstruction* parameter) override; + Status HandleReduce(HloInstruction* reduce, HloInstruction* arg, + HloInstruction* init_value, + tensorflow::gtl::ArraySlice dimensions, + HloComputation* function) override; + Status HandleTuple( + HloInstruction* tuple, + tensorflow::gtl::ArraySlice operands) override; + Status HandleSelect(HloInstruction* select, HloInstruction* pred, + HloInstruction* on_true, + HloInstruction* on_false) override; + Status HandleFusion(HloInstruction* fusion) override; + Status HandleCall(HloInstruction* call, + tensorflow::gtl::ArraySlice operands, + HloComputation* computation) override; + Status HandleCustomCall(HloInstruction* custom_call, + tensorflow::gtl::ArraySlice operands, + tensorflow::StringPiece custom_call_target) override; + Status HandleRng(HloInstruction* random, + RandomDistribution /*distribution*/) override; + + Status FinishVisit(HloInstruction* root) override { return Status::OK(); } + + protected: + // Constructs an IrEmitter with the given IrEmitter context. + // ir_emitter_context is owned by the caller and should outlive the IrEmitter + // object. + explicit IrEmitter(const HloModuleConfig& hlo_module_config, + IrEmitterContext* ir_emitter_context, bool is_nested); + + // A convenient helper for calling HloToIrBindings::GetIrArray. + llvm_ir::IrArray GetIrArray(const HloInstruction& inst) { + return bindings_.GetIrArray(inst); + } + // A convenient helper for calling HloToIrBindings::GetBasePointer. + llvm::Value* GetBasePointer(const HloInstruction& inst) const { + return bindings_.GetBasePointer(inst); + } + // A convenient helper for calling BufferAssignment::GetAllocationIndex. + BufferAllocation::Index GetAllocationIndex(const HloInstruction& hlo) const { + return ir_emitter_context_->buffer_assignment() + .GetUniqueTopLevelAllocation(&hlo) + .ConsumeValueOrDie() + ->index(); + } + + // Emit a singlethreaded or multithreaded loop that computes every element in + // the result of the given HLO instruction. This produces a series of nested + // loops (e.g. one for each dimension of the `hlo`'s shape). The body of the + // inner-most loop is provided by the body_emitter function. + virtual Status EmitTargetElementLoop( + const HloInstruction& hlo, + const llvm_ir::ElementGenerator& body_emitter) = 0; + + // Emits a call in IR to the given nested computation with the given operands + // and output. If no IR function has been previously emitted for the + // computation, also emits such a function. 
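A rough C-level picture of the call being described (illustrative names only; the real callee is an llvm::Function emitted by IrEmitterNested, not a C symbol): operand buffers come first, then the output buffer, and the shared temporary-buffer base is appended as the final argument.

// Stand-in for the emitted nested computation, here a scalar f32 add.
void NestedAddF32(const float* operand0, const float* operand1, float* output,
                  char* /*temp_buffer_base*/) {
  *output = *operand0 + *operand1;
}

void CallNested(const float* a, const float* b, float* out, char* temp_base) {
  NestedAddF32(a, b, out, temp_base);  // operands..., output, temp buffer base
}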
+ Status EmitCallToNestedComputation( + const HloComputation& nested_computation, + tensorflow::gtl::ArraySlice operands, llvm::Value* output); + + // Emits an atomic operation that implements `nested_computation` in the + // sequentially consistent memory model. `output_address` and `source_address` + // are the arguments of the nested computation. For example, + // atomicAdd(output_address, *source_address). + Status EmitAtomicOperationForNestedComputation( + const HloComputation& nested_computation, llvm::Value* output_address, + llvm::Value* source_address); + + GpuElementalIrEmitter::NestedComputer GetNestedComputer() { + return std::bind(&IrEmitter::ComputeNestedElement, this, + std::placeholders::_1, std::placeholders::_2); + } + + IrEmitterContext* ir_emitter_context_; + + // The following fields track the IR emission state. According to LLVM memory + // management rules, their memory is owned by the module. + llvm::IRBuilder<> ir_builder_; + + // Mapping from HLO to its underlying LLVM value. + HloToIrBindings bindings_; + + // Hlo configuration data used during code generation. + const HloModuleConfig& hlo_module_config_; + + private: + // Emits a series of nested loops for iterating over an operand array in the + // dot operation. Loops are constructed in major to minor dimension layout + // order. No loop is emitted for the given reduction_dimension. The function + // returns an IrArray index for the given operand_array containing the indvars + // of the loops. All dimensions of the index are filled except for the + // reduction dimension. name_suffix is the string to append to the names of + // LLVM constructs (eg, basic blocks) constructed by this method. + llvm_ir::IrArray::Index EmitOperandArrayLoopNest( + const llvm_ir::IrArray& operand_array, int64 reduction_dimension, + tensorflow::StringPiece name_suffix, llvm_ir::ForLoopNest* loop_nest); + + // A helper method for EmitAtomicOperationForNestedComputation. Certain + // computations, such as floating-point addition and integer maximization, can + // be simply implemented using an LLVM atomic instruction. If "computation" is + // one of this kind, emits code to do that and returns true; otherwise, + // returns false. + bool MaybeEmitSpecialAtomicOperation(const HloComputation& computation, + llvm::Value* output_address, + llvm::Value* source_address); + + StatusOr ComputeNestedElement( + const HloComputation& computation, + tensorflow::gtl::ArraySlice parameter_elements); + + // Emits an atomic operation that implements `nested_computation` in the + // sequentially consistent memory model. `output_address` and `source_address` + // are the arguments of the nested computation. For example, + // atomicAdd(output_address, *source_address). + StatusOr EmitAtomicFunctionForNestedComputation( + const HloComputation& nested_computation, llvm::Type* element_ir_type); + + // Map nested computations to emitted IR functions. This serves as a cache so + // that IrEmitter does not emit multiple functions for the same + // HloComputation. + std::map computation_to_ir_function_; +}; + +// Emits LLVM IR for unnested computations. Each HloInstruction is translated to +// a separate CUDA kernel. These kernels are inserted into the resultant module +// sorted in reverse postorder of the XLA HLO graph. 
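A conceptual sketch of the thunk sequence this class produces (stand-in types; the real Thunk, KernelThunk and ThunkSequence classes are declared elsewhere in this change): each top-level HloInstruction contributes one step, and the steps are later executed in order to run the compiled computation.

#include <functional>
#include <memory>
#include <vector>

struct ToyThunk {
  std::function<void()> run;  // e.g. launches one emitted CUDA kernel
};
using ToyThunkSequence = std::vector<std::unique_ptr<ToyThunk>>;

void ExecuteAll(const ToyThunkSequence& thunks) {
  for (const auto& thunk : thunks) thunk->run();
}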
+class IrEmitterUnnested : public IrEmitter {
+ public:
+  IrEmitterUnnested(const HloModuleConfig& hlo_module_config,
+                    const HloComputation* hlo_computation,
+                    bool has_hybrid_result,
+                    IrEmitterContext* ir_emitter_context);
+  IrEmitterUnnested(const IrEmitterUnnested&) = delete;
+  IrEmitterUnnested& operator=(const IrEmitterUnnested&) = delete;
+
+  // Transfers the ownership of thunk_sequence_ out.
+  std::unique_ptr<ThunkSequence> ConsumeThunkSequence() {
+    return std::move(thunk_sequence_);
+  }
+
+  Status DefaultAction(HloInstruction* hlo) override;
+
+  // IrEmitterUnnested handles the following instructions differently from
+  // IrEmitter.
+  Status HandleCopy(HloInstruction* copy, HloInstruction* operand) override;
+  Status HandleConvolution(HloInstruction* convolution, HloInstruction* lhs,
+                           HloInstruction* rhs, const Window& window) override;
+  Status HandleDot(HloInstruction* dot, HloInstruction* lhs_instruction,
+                   HloInstruction* rhs_instruction) override;
+  Status HandleFusion(HloInstruction* fusion) override;
+  Status HandleGetTupleElement(HloInstruction* get_tuple_element,
+                               HloInstruction* operand) override;
+  Status HandleReduce(HloInstruction* reduce, HloInstruction* arg,
+                      HloInstruction* init_value,
+                      tensorflow::gtl::ArraySlice<int64> dimensions,
+                      HloComputation* function) override;
+  Status HandleSelectAndScatter(HloInstruction* instruction) override;
+  Status HandleTuple(
+      HloInstruction* tuple,
+      tensorflow::gtl::ArraySlice<HloInstruction*> operands) override;
+  Status HandleWhile(HloInstruction* xla_while, HloInstruction* init,
+                     HloComputation* condition, HloComputation* body) override;
+  Status HandleRng(HloInstruction* random,
+                   RandomDistribution distribution) override;
+  Status HandleSelect(HloInstruction* select, HloInstruction* pred,
+                      HloInstruction* on_true,
+                      HloInstruction* on_false) override;
+
+  Status EmitTargetElementLoop(
+      const HloInstruction& hlo,
+      const llvm_ir::ElementGenerator& body_emitter) override;
+
+  // Same as `EmitTargetElementLoop`, but in the given `thunk` rather than
+  // `LastThunk()`.
+  Status EmitTargetElementLoopInThunk(
+      const HloInstruction& hlo, const llvm_ir::ElementGenerator& body_emitter,
+      KernelThunk* thunk);
+
+ private:
+  // Builds the appropriate thunk for the instruction hlo and returns the owning
+  // pointer to it. The caller needs to make sure `hlo` outlives the lifetime
+  // of the returned Thunk object.
+  std::unique_ptr<Thunk> BuildThunk(const HloInstruction* hlo);
+
+  // Builds the prototype of the IR kernel for `inst` and adds it to the module.
+  llvm::Function* BuildKernelPrototype(
+      const HloInstruction& inst,
+      tensorflow::gtl::ArraySlice<const HloInstruction*> escaped_hlos);
+
+  // Emits the base pointers for `hlo` and its operands. `io_hlos` will store
+  // all input/output HLOs among `hlo` and its operands.
+  llvm::Function* EmitBasePointersForHloAndItsOperands(
+      const HloInstruction& hlo, std::vector<const HloInstruction*>* io_hlos);
+
+  // EmitColumnReduction and EmitRowReduction emit code for column and row
+  // reduction of a matrix and/or 3D tensor. Row and column reduction have
+  // different memory access patterns, so for performance their implementations
+  // are significantly different.
+  //
+  // Emits code that reduces a matrix of shape [height x width] to a vector of
+  // [width]. Other parameters have the same meaning as those of
+  // `EmitReductionToVector`. Note that input shape might not be
+  // [height x width], but can be bitcast to [height x width] with "height"
+  // being the major dimension.
+  Status EmitColumnReduction(int64 height, int64 width, HloInstruction* reduce,
+                             const Shape& input_shape,
+                             const llvm_ir::ElementGenerator& input_gen,
+                             const llvm_ir::ElementGenerator& init_value_gen,
+                             HloComputation* reducer);
+
+  // Emits code that reduces a 3D tensor of shape [depth x height x width] to a
+  // vector of shape [height]. Other parameters have the same meaning as those
+  // of `EmitReductionToVector`. Note that input shape might not be
+  // [depth x height x width], but can be bitcast to [depth x height x width]
+  // with "depth" being the most major dimension.
+  Status EmitRowReduction(int64 depth, int64 height, int64 width,
+                          HloInstruction* reduce, const Shape& input_shape,
+                          const llvm_ir::ElementGenerator& input_gen,
+                          const llvm_ir::ElementGenerator& init_value_gen,
+                          HloComputation* reducer);
+
+  // Figures out whether `reduce` is a row or column reduction, and which
+  // dimensions to reduce, and calls either `EmitRowReduction` or
+  // `EmitColumnReduction` as appropriate. `input_shape` is the shape of the
+  // input array, which is the operand of the Reduce instruction if unfused or
+  // of the Fusion instruction if fused. `input_gen` and `init_value_gen`
+  // generate elements of the input and the initial value. Other parameters mean
+  // the same as for `HandleReduce`.
+  //
+  // Prerequisite: `IsReductionToVector(*reduce)`
+  Status EmitReductionToVector(
+      HloInstruction* reduce, const Shape& input_shape,
+      const llvm_ir::ElementGenerator& input_gen,
+      const llvm_ir::ElementGenerator& init_value_gen,
+      tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
+      HloComputation* reducer);
+
+  // Emits code to initialize the buffer of `inst` in the given `thunk`.
+  Status EmitInitializer(const HloInstruction* inst, KernelThunk* thunk);
+
+  // Returns a KernelThunk that invokes the kernel emitted for `inst`. The
+  // caller needs to make sure `inst` outlives the lifetime of the returned
+  // Thunk object.
+  std::unique_ptr<KernelThunk> BuildKernelThunk(const HloInstruction* inst);
+
+  // Returns a ConvolutionThunk that calls DNN to implement `inst`.
+  std::unique_ptr<ConvolutionThunk> BuildConvolutionThunk(
+      const HloInstruction* inst);
+
+  // Returns a GemmThunk that calls gemm to implement `inst`. The caller needs
+  // to make sure `inst` outlives the lifetime of the returned Thunk object.
+  std::unique_ptr<GemmThunk> BuildGemmThunk(const HloInstruction* inst);
+
+  // Returns a CopyThunk that calls host-to-device cuMemcpy to implement `inst`.
+  std::unique_ptr<CopyThunk> BuildCopyThunk(const HloInstruction* inst);
+
+  // Returns a WhileThunk that invokes thunk sequences for 'condition' and
+  // 'body' sub-computations of while instruction 'hlo'.
+  std::unique_ptr<WhileThunk> BuildWhileThunk(const HloInstruction* hlo);
+
+  // Returns a ForThunk which executes 'loop_limit' invocations of a thunk
+  // sequence from the 'body' sub-computation of the while instruction 'hlo'.
+  std::unique_ptr<ForThunk> BuildForThunk(const HloInstruction* hlo,
+                                          const int64 loop_limit);
+
+  Status Postprocess(HloInstruction* hlo) override;
+
+  // Returns the last generated thunk.
+  Thunk* LastThunk() const { return thunk_sequence_->back().get(); }
+
+  // The thunk sequence this IrEmitter generates for the input computation.
+  std::unique_ptr<ThunkSequence> thunk_sequence_;
+
+  // The HloComputation that this IrEmitter emits code for.
+  const HloComputation* hlo_computation_;
+
+  // Whether this computation will produce a hybrid result, that is, the
+  // computation produces a ShapedBuffer.
+ bool has_hybrid_result_; +}; + +// Emits LLVM IR for a nested computation to the resultant function. +class IrEmitterNested : public IrEmitter { + public: + // Constructs an LLVM IR emitter for a nested HLO computation. `function` is + // the containing IR function this emitter produces IR to. See + // IrEmitter::IrEmitter for the meanings of other arguments. + IrEmitterNested(const HloModuleConfig& hlo_module_config, + const HloComputation& nested_computation, + IrEmitterContext* ir_emitter_context); + IrEmitterNested(const IrEmitterNested&) = delete; + IrEmitterNested& operator=(const IrEmitterNested&) = delete; + + // Overrides the default empty implementation. Binds the given instruction + // "parameter" with the parameter of the IR function. + Status HandleParameter(HloInstruction* parameter) override; + + llvm::Function* GetEmittedFunction() const { return emitted_function_; } + + Status EmitTargetElementLoop( + const HloInstruction& hlo, + const llvm_ir::ElementGenerator& body_emitter) override; + + private: + llvm::Function* EmitBasePointersForNestedComputation( + const HloComputation& nested_computation, + std::vector* io_hlos); + + llvm::Function* emitted_function_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h new file mode 100644 index 0000000000..b204d9625c --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h @@ -0,0 +1,74 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_CONTEXT_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_CONTEXT_H_ + +#include "external/llvm/include/llvm/IR/Module.h" +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/temp_buffer_offsets.h" +#include "tensorflow/compiler/xla/service/name_uniquer.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { + +// IrEmitterContext encapsulates common (mutable and immutable) data structures +// used by both IrEmitterNested and IrEmitterUnnested, such as the buffer +// assignment and the name uniquer. +class IrEmitterContext { + public: + IrEmitterContext(const HloModule* hlo_module, + const BufferAssignment* buffer_assignment, + const TempBufferOffsets* temp_buffer_offsets, + const perftools::gputools::DeviceDescription* device_desc, + llvm::Module* llvm_module) + : hlo_module_(hlo_module), + buffer_assignment_(buffer_assignment), + temp_buffer_offsets_(temp_buffer_offsets), + device_desc_(device_desc), + llvm_module_(llvm_module) {} + // Disallow copy and assign. 
+ IrEmitterContext(const IrEmitterContext&) = delete; + IrEmitterContext& operator=(const IrEmitterContext&) = delete; + + // Simple accessors. + const HloModule& hlo_module() const { return *hlo_module_; } + const BufferAssignment& buffer_assignment() const { + return *buffer_assignment_; + } + const TempBufferOffsets& temp_buffer_offsets() const { + return *temp_buffer_offsets_; + } + const perftools::gputools::DeviceDescription& device_description() const { + return *device_desc_; + } + llvm::Module* llvm_module() { return llvm_module_; } + NameUniquer* name_uniquer() { return &name_uniquer_; } + + private: + const HloModule* hlo_module_; + const BufferAssignment* buffer_assignment_; + const TempBufferOffsets* temp_buffer_offsets_; + const perftools::gputools::DeviceDescription* device_desc_; + llvm::Module* llvm_module_; + NameUniquer name_uniquer_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_CONTEXT_H_ diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc new file mode 100644 index 0000000000..dc5e2d8f02 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc @@ -0,0 +1,120 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include
+#include
+
+#include "external/llvm/include/llvm/IR/BasicBlock.h"
+#include "external/llvm/include/llvm/IR/Function.h"
+#include "external/llvm/include/llvm/IR/IRBuilder.h"
+#include "external/llvm/include/llvm/IR/Instructions.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+#include "tensorflow/compiler/xla/service/name_uniquer.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace xla {
+namespace gpu {
+
+IrEmitterNested::IrEmitterNested(const HloModuleConfig& hlo_module_config,
+                                 const HloComputation& nested_computation,
+                                 IrEmitterContext* ir_emitter_context)
+    : IrEmitter(hlo_module_config, ir_emitter_context, /*is_nested=*/true) {
+  std::vector<const HloInstruction*> io_hlos;
+  emitted_function_ =
+      EmitBasePointersForNestedComputation(nested_computation, &io_hlos);
+}
+
+llvm::Function* IrEmitterNested::EmitBasePointersForNestedComputation(
+    const HloComputation& nested_computation,
+    std::vector<const HloInstruction*>* io_hlos) {
+  std::vector<llvm::Type*> argument_types;
+  std::vector<int64> argument_dereferenceable_bytes;
+  for (const HloInstruction* param :
+       nested_computation.parameter_instructions()) {
+    io_hlos->push_back(param);
+    const Shape& param_shape = param->shape();
+    argument_types.push_back(
+        llvm_ir::ShapeToIrType(param_shape, &ir_builder_)->getPointerTo());
+    int64 param_size = llvm_ir::ByteSizeOf(
+        param_shape, ir_emitter_context_->llvm_module()->getDataLayout());
+    argument_dereferenceable_bytes.push_back(param_size);
+  }
+  {
+    const HloInstruction* root = nested_computation.root_instruction();
+    io_hlos->push_back(root);
+    const Shape& root_shape = root->shape();
+    argument_types.push_back(
+        llvm_ir::ShapeToIrType(root_shape, &ir_builder_)->getPointerTo());
+    int64 root_size = llvm_ir::ByteSizeOf(
+        root_shape, ir_emitter_context_->llvm_module()->getDataLayout());
+    argument_dereferenceable_bytes.push_back(root_size);
+  }
+  // The base pointer of the memory block for all pre-allocated temp buffers.
+  argument_types.push_back(ir_builder_.getInt8PtrTy());
+
+  llvm::FunctionType* function_type =
+      llvm::FunctionType::get(ir_builder_.getVoidTy(), argument_types, false);
+  llvm::Function* function = llvm::Function::Create(
+      function_type,                       // The function type.
+      llvm::GlobalValue::InternalLinkage,  // The linkage type.
+      llvm_ir::AsStringRef(ir_emitter_context_->name_uniquer()->GetUniqueName(
+          llvm_ir::SanitizeIrName(
+              nested_computation.name()))),  // The name of the function.
+      ir_emitter_context_->llvm_module());   // The parent LLVM module.
+  for (size_t arg_no = 0; arg_no < argument_dereferenceable_bytes.size();
+       ++arg_no) {
+    int64 arg_size = argument_dereferenceable_bytes[arg_no];
+    if (arg_size > 0) {
+      function->addDereferenceableAttr(arg_no + 1, arg_size);
+    }
+  }
+
+  llvm::BasicBlock* entry_bb =
+      llvm::BasicBlock::Create(function->getContext(), "entry", function);
+  // Emit a "return void" at entry_bb's end, and set the insert point before
+  // that return instruction.
+ ir_builder_.SetInsertPoint( + llvm::ReturnInst::Create(function->getContext(), entry_bb)); + + std::vector non_io_hlos; + for (const auto& hlo : nested_computation.instructions()) { + if (hlo->opcode() != HloOpcode::kParameter && + hlo.get() != nested_computation.root_instruction()) { + non_io_hlos.push_back(hlo.get()); + } + } + bindings_.EmitBasePointersForHlos(*io_hlos, non_io_hlos); + return function; +} + +Status IrEmitterNested::HandleParameter(HloInstruction* parameter) { + return Status::OK(); +} + +Status IrEmitterNested::EmitTargetElementLoop( + const HloInstruction& hlo, + const llvm_ir::ElementGenerator& element_generator) { + return llvm_ir::LoopEmitter(element_generator, GetIrArray(hlo), &ir_builder_) + .EmitLoop(); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc new file mode 100644 index 0000000000..79a6443346 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -0,0 +1,1745 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "external/llvm/include/llvm/ADT/StringRef.h" +#include "external/llvm/include/llvm/IR/BasicBlock.h" +#include "external/llvm/include/llvm/IR/Function.h" +#include "external/llvm/include/llvm/IR/IRBuilder.h" +#include "external/llvm/include/llvm/IR/Instructions.h" +#include "external/llvm/include/llvm/IR/LLVMContext.h" +#include "external/llvm/include/llvm/IR/Module.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" +#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/copy_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/for_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" +#include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h" +#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/thunk.h" +#include "tensorflow/compiler/xla/service/gpu/tuple_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/while_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/while_transformer.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" 
+#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ops.h" +#include "tensorflow/compiler/xla/service/name_uniquer.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/window_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { +namespace gpu { + +namespace { + +// If a dimensions is smaller than this, untiled transposition may be more +// efficient. +const int64 kMinDimensionToTransposeTiled = 16; + +// Returns true if all paths from `hlo` to `root` contain only tuples. The +// result of such an HloInstruction does not need to be materialized, when the +// computation can have a hybrid result. +bool ReachRootViaOnlyTuples(const HloInstruction& hlo, + const HloInstruction& root) { + if (hlo.opcode() != HloOpcode::kTuple) { + return false; + } + + if (&hlo == &root) { + return true; + } + + for (HloInstruction* user : hlo.users()) { + if (!ReachRootViaOnlyTuples(*user, root)) { + return false; + } + } + + return true; +} + +// If `hlo` is a Transpose, returns its operand; otherwise returns `hlo` itself. +const HloInstruction* StripTranspose(const HloInstruction& hlo) { + if (hlo.IsRank2Transpose()) { + return hlo.operand(0); + } + return &hlo; +} + +// Updates the launch dimensions in "thunk" and annotate the launch dimensions +// of the corresponding IR kernel in "llvm_module". +// Precondition: "thunk" must be a KernelThunk. +void UpdateLaunchDimensions(const LaunchDimensions& launch_dims, Thunk* thunk, + llvm::Module* llvm_module) { + CHECK(Thunk::Kind::kKernel == thunk->kind()); + KernelThunk* kernel_thunk = static_cast(thunk); + kernel_thunk->SetLaunchDimensions(launch_dims); + + // Add __launch_bounds__ to metadata. This limits registers per thread to + // avoid out-of-resources launching errors. + llvm::NamedMDNode* nvvm_annotations_node = + llvm_module->getOrInsertNamedMetadata("nvvm.annotations"); + llvm::Function* ir_kernel = + llvm_module->getFunction(kernel_thunk->kernel_name().c_str()); + llvm::LLVMContext& llvm_context = llvm_module->getContext(); + llvm::ConstantInt* threads_per_block_ir_value = llvm::ConstantInt::get( + llvm::IntegerType::get(llvm_context, /*NumBits=*/32), + launch_dims.threads_per_block()); + nvvm_annotations_node->addOperand(llvm::MDNode::get( + llvm_context, + {llvm::ConstantAsMetadata::get(ir_kernel), + llvm::MDString::get(llvm_context, "maxntidx"), + llvm::ConstantAsMetadata::get(threads_per_block_ir_value)})); +} +} // namespace + +IrEmitterUnnested::IrEmitterUnnested(const HloModuleConfig& hlo_module_config, + const HloComputation* hlo_computation, + bool has_hybrid_result, + IrEmitterContext* ir_emitter_context) + : IrEmitter(hlo_module_config, ir_emitter_context, /*is_nested=*/false), + hlo_computation_(hlo_computation), + has_hybrid_result_(has_hybrid_result) { + // Initialize thunk_sequence_ to an empty list of thunks. 
+ thunk_sequence_.reset(new ThunkSequence()); +} + +Status IrEmitterUnnested::Postprocess(HloInstruction* hlo) { + bindings_.UnbindAllLocalIrValues(); + return DfsHloVisitor::Postprocess(hlo); +} + +namespace { +bool ImplementedAsMemcpy(const HloInstruction& hlo) { + // `hlo` needs to satisfy three conditions to be implemented as a + // host-to-device cuMemcpy. + // + // 1. `hlo` is a kCopy instruction. + // 2. `hlo`'s only operand is a kConstant instruction. + // 3. `hlo` and its operand have the same shape (thus the same layout too). + return hlo.opcode() == HloOpcode::kCopy && + hlo.operand(0)->opcode() == HloOpcode::kConstant && + ShapeUtil::Equal(hlo.operand(0)->shape(), hlo.shape()); +} +} // namespace + +llvm::Function* IrEmitterUnnested::BuildKernelPrototype( + const HloInstruction& inst, + tensorflow::gtl::ArraySlice escaped_hlos) { + // Compute the kernel name. The opcode string may contain "-" which cannot be + // in a PTX function name, so sanitize the name before uniquifying it. + string kernel_name = ir_emitter_context_->name_uniquer()->GetUniqueName( + llvm_ir::SanitizeIrName(inst.name())); + + // Create the kernel and adds it to the module. + llvm::Module* module = ir_emitter_context_->llvm_module(); + llvm::LLVMContext& context = module->getContext(); + int num_escaped_hlos = escaped_hlos.size(); + llvm::FunctionType* kernel_type = llvm::FunctionType::get( + llvm::Type::getVoidTy(context), // The type of function result. + std::vector(num_escaped_hlos + 1, + ir_builder_.getInt8PtrTy()), + false); // Not a variadic argument function. + llvm::Function* kernel = + llvm::Function::Create(kernel_type, llvm::GlobalValue::ExternalLinkage, + kernel_name.c_str(), module); + + // Add dereferenceable information to each of the escaped HLO parameters. + for (size_t arg_no = 0; arg_no < escaped_hlos.size(); ++arg_no) { + const HloInstruction* escaped_hlo = escaped_hlos[arg_no]; + const Shape& escaped_hlo_shape = escaped_hlo->shape(); + int64 escaped_hlo_size = llvm_ir::ByteSizeOf( + escaped_hlo_shape, ir_emitter_context_->llvm_module()->getDataLayout()); + kernel->addDereferenceableAttr(arg_no + 1, escaped_hlo_size); + } + + // The last argument is a pointer to the temporary buffer memory block. + // We know that it doesn't alias any of the escaped arguments (the inputs + + // the result). We also know how many bytes can be dereferenced in it. + const llvm::Argument& temp_buffer = kernel->getArgumentList().back(); + int64 temp_buffer_size = + ir_emitter_context_->temp_buffer_offsets().TotalSizeInBytes(); + int64 temp_buffer_arg_no = temp_buffer.getArgNo(); + if (temp_buffer_size > 0) { + kernel->addDereferenceableAttr(temp_buffer_arg_no + 1, temp_buffer_size); + } + kernel->setDoesNotAlias(temp_buffer_arg_no + 1); + + // Add the declaration of this kernel to llvm.nvvm.annotations so that NVPTX + // treats it as a CUDA kernel. + llvm::NamedMDNode* nvvm_annotations_node = + module->getOrInsertNamedMetadata("nvvm.annotations"); + nvvm_annotations_node->addOperand(llvm::MDNode::get( + context, {llvm::ConstantAsMetadata::get(kernel), + llvm::MDString::get(context, "kernel"), + llvm::ConstantAsMetadata::get(ir_builder_.getInt32(1))})); + + // Update the insert point to the entry basic block. + llvm::BasicBlock* entry_bb = + llvm::BasicBlock::Create(context, + "entry", // The name of the basic block. + kernel); // The parent/owner of "entry_bb". + // Emit a "return void" at entry_bb's end, and sets the insert point before + // that return instruction. 
+  ir_builder_.SetInsertPoint(llvm::ReturnInst::Create(context, entry_bb));
+
+  return kernel;
+}
+
+Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) {
+  thunk_sequence_->emplace_back(BuildKernelThunk(hlo));
+  return IrEmitter::DefaultAction(hlo);
+}
+
+Status IrEmitterUnnested::HandleDot(HloInstruction* dot,
+                                    HloInstruction* lhs_instruction,
+                                    HloInstruction* rhs_instruction) {
+  if (ImplementedAsGemm(*dot)) {
+    thunk_sequence_->emplace_back(BuildGemmThunk(dot));
+    return Status::OK();
+  }
+  thunk_sequence_->emplace_back(BuildKernelThunk(dot));
+  return IrEmitter::HandleDot(dot, lhs_instruction, rhs_instruction);
+}
+
+Status IrEmitterUnnested::HandleConvolution(HloInstruction* convolution,
+                                            HloInstruction* lhs_instruction,
+                                            HloInstruction* rhs_instruction,
+                                            const Window& window) {
+  if (ImplementedAsDnnConvolution(*convolution)) {
+    thunk_sequence_->emplace_back(BuildConvolutionThunk(convolution));
+    return Status::OK();
+  }
+  thunk_sequence_->emplace_back(BuildKernelThunk(convolution));
+  return IrEmitter::HandleConvolution(convolution, lhs_instruction,
+                                      rhs_instruction, window);
+}
+
+Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
+  HloInstruction* root = fusion->fused_expression_root();
+  // HandleFusion specializes reduction from a multi-dimensional array to a 1D
+  // array. The specialized version requires an initializer thunk that
+  // initializes the output array to the initial value of the reduce.
+  if (HloInstruction::FusionKind::kInput == fusion->fusion_kind()) {
+    switch (root->opcode()) {
+      case HloOpcode::kReduce: {
+        VLOG(3) << "Emitting fused reduction to vector: " << fusion->ToString();
+        std::vector<std::unique_ptr<Thunk>> thunks;
+        thunks.emplace_back(BuildKernelThunk(fusion));
+        TF_RETURN_IF_ERROR(EmitInitializer(
+            fusion, static_cast<KernelThunk*>(thunks.back().get())));
+        bindings_.UnbindAllLocalIrValues();
+        thunks.emplace_back(BuildKernelThunk(fusion));
+        thunk_sequence_->emplace_back(
+            MakeUnique<SequentialThunk>(std::move(thunks), fusion));
+        std::vector<llvm_ir::IrArray> parameter_arrays;
+        for (HloInstruction* operand : fusion->operands()) {
+          parameter_arrays.push_back(GetIrArray(*operand));
+        }
+        GpuElementalIrEmitter elemental_emitter(
+            hlo_module_config_, ir_emitter_context_->llvm_module(),
+            &ir_builder_, GetNestedComputer());
+        FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter);
+        TF_RETURN_IF_ERROR(root->Accept(&fused_emitter));
+
+        Shape input_shape = root->operand(0)->shape();
+        // EmitReductionToVector requires the input shape to have a layout, but
+        // fused instructions don't have one. So we determine its layout from
+        // the fusion's operands. The choice of the layout only affects
+        // performance but not correctness.
+        auto choose_input_layout = [](
+            tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+            Shape* input_shape) {
+          // Prefer the layout of an operand whose shape is compatible with
+          // input_shape.
+          for (const HloInstruction* operand : operands) {
+            if (ShapeUtil::Compatible(*input_shape, operand->shape())) {
+              LayoutUtil::CopyLayoutBetweenShapes(operand->shape(),
+                                                  input_shape);
+              return;
+            }
+          }
+          // If no operand has a compatible shape, prefer an operand that has
+          // the same rank at least.
+          for (const HloInstruction* operand : operands) {
+            if (ShapeUtil::Rank(*input_shape) ==
+                ShapeUtil::Rank(operand->shape())) {
+              // Do not use CopyLayoutBetweenShapes because input_shape and
+              // operand->shape() may be incompatible.
+              *input_shape->mutable_layout() = operand->shape().layout();
+              return;
+            }
+          }
+          // When all the above fails, which is rare, set the default layout.
+          LayoutUtil::SetToDefaultLayout(input_shape);
+        };
+        choose_input_layout(fusion->operands(), &input_shape);
+
+        return EmitReductionToVector(
+            root, input_shape, fused_emitter.GetGenerator(root->operand(0)),
+            fused_emitter.GetGenerator(root->operand(1)), root->dimensions(),
+            root->to_apply());
+        break;
+      }
+      default:
+        LOG(FATAL) << "Bad opcode for input fusion: "
+                   << fusion->fused_expression_root()->opcode();
+    }
+  }
+  if (ImplementedAsGemm(*fusion)) {
+    thunk_sequence_->emplace_back(BuildGemmThunk(fusion));
+    return Status::OK();
+  }
+  if (ImplementedAsDnnConvolution(*fusion)) {
+    thunk_sequence_->emplace_back(BuildConvolutionThunk(fusion));
+    return Status::OK();
+  }
+  thunk_sequence_->emplace_back(BuildKernelThunk(fusion));
+  return IrEmitter::HandleFusion(fusion);
+}
+
+namespace {
+
+// Returns the indices of the first elements of all consecutive subarrays of the
+// given array. For example:
+// ConsecutiveSegments({m, m+1, m+2, n, k, k+1}) = {0, 3, 4}
+std::vector<size_t> ConsecutiveSegments(tensorflow::gtl::ArraySlice<int64> xs) {
+  std::vector<size_t> is = {0};
+  for (size_t i = 1; i < xs.size(); ++i) {
+    if (1 != xs[i] - xs[i - 1]) {
+      is.push_back(i);
+    }
+  }
+  return is;
+}
+
+// Merges the sequences of dimensions of the given shape which start at the
+// given indices `segs`.
+Shape MergeDimensions(tensorflow::gtl::ArraySlice<size_t> segs,
+                      const Shape& shape) {
+  std::vector<int64> dimensions;
+  for (size_t i = 1; i <= segs.size(); ++i) {
+    dimensions.push_back(std::accumulate(
+        shape.dimensions().begin() + segs[i - 1],
+        shape.dimensions().begin() +
+            (segs.size() == i ? shape.dimensions().size() : segs[i]),
+        1, std::multiplies<int64>()));
+  }
+  return ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(shape.element_type(),
+                                                          dimensions);
+}
+
+// Returns whether the given shapes and permutation are a 0-2-1 transpose, and
+// if so, the normalized and rank-reduced shapes. The shapes must have the same
+// dimensions, so this considers layout only.
+//
+// This function recognizes higher-rank transposes which are elementwise
+// equivalent to a 0-2-1 transpose.
+std::tuple<bool, Shape, Shape> IsTranspose021(const Shape& a, const Shape& b) {
+  CHECK(ShapeUtil::Compatible(a, b));
+  std::vector<int64> perm(a.dimensions().size());
+  {
+    std::vector<int64> layout_a(a.layout().minor_to_major().rbegin(),
+                                a.layout().minor_to_major().rend());
+    std::vector<int64> layout_b(b.layout().minor_to_major().rbegin(),
+                                b.layout().minor_to_major().rend());
+    for (size_t i = 0; i < perm.size(); ++i) {
+      perm[i] = PositionInContainer(layout_b, layout_a[i]);
+    }
+  }
+  auto segs = ConsecutiveSegments(perm);
+  Shape norm_a = ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(a);
+  Shape norm_b = ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(b);
+  if (3 == segs.size() && 0 == perm[0]) {
+    Shape reduced_a = MergeDimensions(segs, norm_a);
+    Shape reduced_b = ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
+        b.element_type(),
+        Permute({0, 2, 1}, AsInt64Slice(reduced_a.dimensions())));
+    return std::make_tuple(true, reduced_a, reduced_b);
+  }
+  return std::make_tuple(false, ShapeUtil::MakeNil(), ShapeUtil::MakeNil());
+}
+
+// Returns whether the given shapes are potentially of a 0-2-1 transpose.
+// As 0-2-1 is a self-inverse permutation, which shape is input or output is
+// arbitrary.
+bool AreShapesForTranspose021(const Shape& a, const Shape& b) { + return 3 == b.dimensions().size() && + ShapeUtil::Compatible( + ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(a), + ShapeUtil::PermuteDimensions( + {0, 2, 1}, + ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(b))); +} + +// Emits a tiled 0-2-1 transpose, assuming both input and output lain out from +// major to minor. The x- and y- dimensions are tiled in square tiles of edge +// length `tile_size`. Each thread block of `tile_size` threads transposes one +// tile: each thread copies a row from the input to a shared memory tile, then +// copies a column from the shared memory tile to the output. +// +// `tile_size` should usually be same as warp size. +// +// Returns (number of tiles = number of thread blocks needed). +// +// TODO(b/33320379): Here each block transposes 1 tile. It may be more efficient +// to launch fewer blocks so each transposes many tiles, and +// in any case, the number of blocks we can launch is limited. +// +// This is the same algorithm in CUDA: +// https://github.com/tensorflow/tensorflow/blob/6172351b81af76d0b819fea6bb478cbd4016d6c2/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc#L183 +int64 EmitTranspose021Tiled(llvm_ir::IrArray input, llvm_ir::IrArray output, + const int64 tile_size, llvm::IRBuilder<>* builder) { + // Adds `addend` to the given `dim` of `index`. + auto offset_dim = [builder](llvm_ir::IrArray::Index index, + llvm::Value* addend, int64 dim) { + index[dim] = builder->CreateAdd(index[dim], addend); + return index; + }; + + CHECK(AreShapesForTranspose021(input.GetShape(), output.GetShape())); + + Shape input_shape = + ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(input.GetShape()); + Shape output_shape = + ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(output.GetShape()); + input = input.CastToShape(input_shape, builder); + output = output.CastToShape(output_shape, builder); + + llvm::Type* tile_type = llvm::ArrayType::get( + llvm::ArrayType::get(input.GetElementLlvmType(), tile_size), + // One extra here to avoid share memory bank conflict + tile_size + 1); + auto* tile = new llvm::GlobalVariable( + *builder->GetInsertBlock()->getParent()->getParent(), tile_type, + /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage, + llvm::UndefValue::get(tile_type), "tile", nullptr, + llvm::GlobalValue::NotThreadLocal, + /*AddressSpace=*/3 /* GPU shared memory */); + + // let x = threadIdx.x + llvm::Value* x = llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, builder); + llvm_ir::AddRangeMetadata(0, tile_size, static_cast(x)); + x = builder->CreateIntCast(x, builder->getInt64Ty(), /*isSigned=*/true, + "thread.id.x"); + + // `emit_cp` emits equivalent to following pseudocode: + // if (tile_size == tile_width && tile_size == tile_height) { + // unroll for (y in 0..tile_size) { + // emit_cp_element(index + {0, y, 0}, y); + // } + // } else if (x < tile_width) { + // for (y in 0..tile_height) { + // emit_cp_element(index + {0, y, 0}, y); + // } + // } + // + // We use this to emit both the copy from input to tile and the copy from tile + // to output. + // + // `index` is the origin of the row or column in the input or output array. 
+  //
+  // `emit_cp_element(index, y)` emits code to copy a single element between the
+  // tile and the input or output array, where `y` is the `y`-position in the
+  // tile; whether that is a row or a column depends on whether we are copying
+  // from the input or to the output, and `index` is the index into the input
+  // or output array.
+  auto emit_cp_tile = [builder, tile_size, x, &offset_dim](
+      std::function<void(const llvm_ir::IrArray::Index&, llvm::Value*)>
+          emit_cp_element,
+      llvm::Value* tile_width, llvm::Value* tile_height,
+      const llvm_ir::IrArray::Index& index, const string& loop_name) {
+    llvm_ir::LlvmIfData if_not_last_row = llvm_ir::EmitIfThenElse(
+        builder->CreateAnd(
+            builder->CreateICmpEQ(builder->getInt64(tile_size), tile_width),
+            builder->CreateICmpEQ(builder->getInt64(tile_size), tile_height)),
+        "not_last_row", builder);
+    builder->SetInsertPoint(if_not_last_row.true_block->getTerminator());
+    for (int64 i = 0; i < tile_size; ++i) {
+      emit_cp_element(offset_dim(index, builder->getInt64(i), /*dim=*/1),
+                      builder->getInt64(i));
+    }
+    builder->SetInsertPoint(if_not_last_row.false_block->getTerminator());
+    llvm_ir::LlvmIfData if_in_tile = llvm_ir::EmitIfThenElse(
+        builder->CreateICmpULT(x, tile_width), "in_tile", builder);
+    builder->SetInsertPoint(if_in_tile.true_block->getTerminator());
+    auto loop = llvm_ir::ForLoop::EmitForLoop(loop_name, builder->getInt64(0),
+                                              tile_height, builder->getInt64(1),
+                                              builder);
+    llvm_ir::SetToFirstInsertPoint(loop->GetHeaderBasicBlock(), builder);
+    builder->SetInsertPoint(loop->GetBodyBasicBlock()->getTerminator());
+    emit_cp_element(offset_dim(index, loop->GetIndVarValue(), /*dim=*/1),
+                    loop->GetIndVarValue());
+    builder->SetInsertPoint(if_not_last_row.after_block->getTerminator());
+  };
+
+  auto input_dims_in_tiles = input_shape.dimensions();
+  // Unpermuted dimensions are untiled.
+  for (int i = 1; i < 3; ++i) {
+    input_dims_in_tiles[i] =
+        CeilOfRatio(input_dims_in_tiles[i], tile_size);
+  }
+  int64 num_tiles =
+      std::accumulate(input_dims_in_tiles.begin(), input_dims_in_tiles.end(), 1,
+                      std::multiplies<int64>());
+  const llvm_ir::IrArray::Index input_tile_index(
+      /*linear=*/builder->CreateIntCast(
+          llvm_ir::AddRangeMetadata(
+              0, num_tiles,
+              static_cast<llvm::Instruction*>(llvm_ir::EmitCallToIntrinsic(
+                  llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {},
+                  builder))),
+          builder->getInt64Ty(), /*isSigned=*/true, "block.id.x"),
+      ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
+          PRED /*arbitrary*/, AsInt64Slice(input_dims_in_tiles)),
+      builder);
+  const llvm_ir::IrArray::Index input_tile_origin = ({
+    llvm_ir::IrArray::Index index = input_tile_index;
+    for (int i = 1; i < 3; ++i) {
+      index[i] = builder->CreateMul(index[i], builder->getInt64(tile_size),
+                                    "tile_origin." + std::to_string(i));
+    }
+    index;
+  });
+  const llvm_ir::IrArray::Index input_index =
+      offset_dim(input_tile_origin, x, /*dim=*/2);
+  std::vector<llvm::Value*> tile_dims(input_shape.dimensions().size());
+  // Only last row or column may not have full size.
+  for (int i = 1; i < 3; ++i) {
+    tile_dims[i] = builder->CreateSelect(
+        builder->CreateICmpEQ(input_tile_index[i],
+                              builder->getInt64(input_dims_in_tiles[i] - 1)),
+        builder->getInt64(input_shape.dimensions(i) -
+                          (input_dims_in_tiles[i] - 1) * tile_size),
+        builder->getInt64(tile_size), "tile_size");
+  }
+
+  // Load data from input memory to shared memory tile.
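+  // In hand-written CUDA, this load phase would look roughly like the
+  // following (illustrative sketch only, with the untiled major dimension
+  // omitted; tile_height, tile_origin_y and tile_origin_x stand for the
+  // runtime values computed above):
+  //
+  //   for (int y = 0; y < tile_height; ++y)
+  //     tile[y][threadIdx.x] =
+  //         input[tile_origin_y + y][tile_origin_x + threadIdx.x];
+  //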
+ emit_cp_tile( + // tile[y, x] = input_array[index] + [builder, tile, x, &input](const llvm_ir::IrArray::Index& index, + llvm::Value* y) { + builder->CreateStore( + input.EmitReadArrayElement(index, builder, "input_element"), + builder->CreateGEP(tile, {builder->getInt64(0), y, x})); + }, + tile_dims[2], tile_dims[1], input_index, "input"); + + // Wait for all threads to reach this point, lest we copy a value from tile to + // output before the other thread copies it from input to tile. + // This is `__syncthreads` in CUDA. + llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, builder); + + const llvm_ir::IrArray::Index output_tile_index( + Permute({0, 2, 1}, input_tile_index.multidim())); + const llvm_ir::IrArray::Index output_tile_origin( + Permute({0, 2, 1}, input_tile_origin.multidim())); + const llvm_ir::IrArray::Index output_index = + offset_dim(output_tile_origin, x, /*dim=*/2); + + // Store data from shared memory tile to output memory. + emit_cp_tile( + // output_array[index] = tile[x, y] + [builder, tile, x, &output](const llvm_ir::IrArray::Index& index, + llvm::Value* y) { + output.EmitWriteArrayElement( + index, + builder->CreateLoad( + builder->CreateGEP(tile, {builder->getInt64(0), x, y}), + "output_element"), + builder); + }, + tile_dims[1], tile_dims[2], output_index, "output"); + + return num_tiles; +} + +} // namespace + +Status IrEmitterUnnested::HandleCopy(HloInstruction* copy, + HloInstruction* operand) { + if (ImplementedAsMemcpy(*copy)) { + thunk_sequence_->emplace_back(BuildCopyThunk(copy)); + return Status::OK(); + } + bool is_transpose_021; + Shape reduced_input_shape, reduced_output_shape; + std::tie(is_transpose_021, reduced_input_shape, reduced_output_shape) = + IsTranspose021(operand->shape(), copy->shape()); + if (is_transpose_021 && + reduced_input_shape.dimensions(1) >= kMinDimensionToTransposeTiled && + reduced_input_shape.dimensions(2) >= kMinDimensionToTransposeTiled) { + thunk_sequence_->emplace_back(BuildKernelThunk(copy)); + VLOG(3) << "Emitting tiled 0-2-1 transposition"; + constexpr int64 tile_size = 32; + int64 num_tiles = EmitTranspose021Tiled( + GetIrArray(*operand).CastToShape(reduced_input_shape, &ir_builder_), + GetIrArray(*copy).CastToShape(reduced_output_shape, &ir_builder_), + tile_size, &ir_builder_); + UpdateLaunchDimensions(LaunchDimensions(num_tiles, tile_size), LastThunk(), + ir_emitter_context_->llvm_module()); + return Status::OK(); + } + + return IrEmitter::HandleCopy(copy, operand); +} + +Status IrEmitterUnnested::EmitColumnReduction( + int64 height, int64 width, HloInstruction* reduce, const Shape& input_shape, + const llvm_ir::ElementGenerator& input_gen, + const llvm_ir::ElementGenerator& init_value_gen, HloComputation* reducer) { + // Divide the input matrix into tiles of size Kx1. For example, when the + // input matrix is 4x4 and K=2, the tiled matrix looks like + // + // 0123 + // 0123 + // 4567 + // 4567 // Numbers indicate tile IDs. + // + // Each tile is first partially reduced to a scalar by a thread, and then the + // scalar is accumulated to the output vector using atomic operations. We + // choose 16 as the tile size, which matches Eigen's ColumnReduceKernel. + constexpr int64 kTileSize = 16; + // If the height is not a multiple of the tile size, we pad the bottom of the + // input matrix. 
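+  // For example (illustrative numbers only): with height = 35 and
+  // kTileSize = 16, height_in_tiles below is CeilOfRatio(35, 16) = 3; the last
+  // tile covers only rows 32..34, and the bounds check emitted further down
+  // skips the missing rows.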
+ const int64 height_in_tiles = CeilOfRatio(height, kTileSize); + + // for (linear_index = threadIdx.x + blockIdx.x * blockDim.x; + // linear_index < height_in_tiles * width; + // linear_index += blockDim.x * gridDim.x) { + // y_in_tiles = linear_index / width; + // x = linear_index % width; + // + // partial_result = init_value; + // if (height % kTileSize == 0 || + // y_in_tiles * kTileSize + kTileSize <= height) { + // for (element_id_in_tile : range(kTileSize)) { + // y = y_in_tiles * kTileSize + element_id_in_tile; + // partial_result = Reducer(partial_result, input[y][x]); + // } + // } else { + // for (element_id_in_tile : range(kTileSize)) { + // y = y_in_tiles * kTileSize + element_id_in_tile; + // if (y < height) { + // partial_result = Reducer(partial_result, input[y][x]); + // } + // } + // } + // AtomicReducer(&output[x], partial_result); + // } + auto loop_body_emitter = + [=](const llvm_ir::IrArray::Index& tile_index) -> Status { + // Emit the loop body that reduces one tile. + llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType( + input_shape.element_type(), &ir_builder_); + llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca( + element_ir_type, /*ArraySize=*/nullptr, "partial_reduction_result"); + { + TF_ASSIGN_OR_RETURN(llvm::Value * init_ir_value, + init_value_gen(llvm_ir::IrArray::Index({}))); + ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address); + } + + // Emit an inner for-loop that partially reduces the elements in the given + // tile. + llvm::Value* y_in_tiles = tile_index[0]; + llvm::Value* x = tile_index[1]; + + auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status { + std::unique_ptr tile_element_loop = + llvm_ir::ForLoop::EmitForLoop("element_id_in_tile", + ir_builder_.getInt64(0), + ir_builder_.getInt64(kTileSize), + ir_builder_.getInt64(1), &ir_builder_); + + // Emit the body of the partial reduction loop. + llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(), + &ir_builder_); + llvm::Value* y = ir_builder_.CreateNSWAdd( + ir_builder_.CreateNSWMul(y_in_tiles, ir_builder_.getInt64(kTileSize)), + tile_element_loop->GetIndVarValue()); + // Unless we know the tile is entirely in bounds, we have to emit a + // y-in-bounds check before reading from the input. + if (!tile_in_bounds) { + llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( + ir_builder_.CreateICmpULT(y, ir_builder_.getInt64(height)), + "y_in_bounds", &ir_builder_); + + // Emit code that reads the input element and accumulates it to + // the partial reduction result. + llvm_ir::SetToFirstInsertPoint(if_data.true_block, &ir_builder_); + } + llvm::Value* input_address = ir_builder_.CreateAlloca(element_ir_type); + { + // {y,x} is an index to input_matrix_shape [height,width]. We need to + // convert that to an index to input_shape (the shape of the operand of + // "reduce"). This conversion is composed of a transposition from + // input_shape to normalized_input_shape and a reshape from + // normalized_input_shape to input_matrix_shape. 
+ const Shape normalized_input_shape = + ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(input_shape); + const std::vector transpose_dimension_mapping( + input_shape.layout().minor_to_major().rbegin(), + input_shape.layout().minor_to_major().rend()); + + const Shape input_matrix_shape = + ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( + input_shape.element_type(), {height, width}); + const llvm_ir::IrArray::Index input_matrix_index( + {y, x}, input_matrix_shape, &ir_builder_); + const llvm_ir::IrArray::Index input_index = + input_matrix_index + .SourceIndexOfReshape(input_matrix_shape, + normalized_input_shape, &ir_builder_) + .SourceIndexOfTranspose(normalized_input_shape, input_shape, + transpose_dimension_mapping, + &ir_builder_); + TF_ASSIGN_OR_RETURN(llvm::Value * input_ir_value, + input_gen(input_index)); + ir_builder_.CreateStore(input_ir_value, input_address); + } + return (EmitCallToNestedComputation( + *reducer, {partial_reduction_result_address, input_address}, + partial_reduction_result_address)); + }; + + // y_end = kTileSize + y_in_tiles * kTileSize, i.e., the y location that's + // immediately beyond the tile. + llvm::Value* y_end = ir_builder_.CreateNSWAdd( + ir_builder_.getInt64(kTileSize), + ir_builder_.CreateNSWMul(y_in_tiles, ir_builder_.getInt64(kTileSize))); + llvm::Value* tile_in_bounds = ir_builder_.CreateOr( + ir_builder_.CreateICmpULE(y_end, ir_builder_.getInt64(height)), + ir_builder_.getInt1(height % kTileSize == 0)); + // The tile is entirely in bound if "height" is a multiple of kTileSize or + // y_end <= height. + llvm_ir::LlvmIfData if_tile_in_bounds_data = + llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &ir_builder_); + llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block, + &ir_builder_); + TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/true)); + llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.false_block, + &ir_builder_); + TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/false)); + + // After the if-then-else statement on tile_in_bounds, emit atomic + // operations to accumulate the partial reduction result to the output + // element. + llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block, + &ir_builder_); + const HloInstruction* output = + reduce->IsFused() ? reduce->fusion_instruction() : reduce; + llvm::Value* output_address = GetIrArray(*output).EmitArrayElementAddress( + llvm_ir::IrArray::Index(x, output->shape(), &ir_builder_), &ir_builder_, + "output_element_address"); + return EmitAtomicOperationForNestedComputation( + *reducer, output_address, partial_reduction_result_address); + }; + + // Emit a parallel loop that iterate through all input tiles. 
+  Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout(
+      reduce->shape().element_type(), {height_in_tiles, width}, {1, 0});
+  LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
+      tiled_input_shape, ir_emitter_context_->device_description());
+  CHECK(LastThunk()->kind() == Thunk::Kind::kSequential);
+  UpdateLaunchDimensions(
+      launch_dimensions,
+      static_cast<SequentialThunk*>(LastThunk())->thunks().back().get(),
+      ir_emitter_context_->llvm_module());
+  return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape,
+                             launch_dimensions, &ir_builder_)
+      .EmitLoop();
+}
+
+Status IrEmitterUnnested::EmitRowReduction(
+    int64 depth, int64 height, int64 width, HloInstruction* reduce,
+    const Shape& input_shape, const llvm_ir::ElementGenerator& input_gen,
+    const llvm_ir::ElementGenerator& init_value_gen, HloComputation* reducer) {
+  // A naive algorithm is:
+  // 1. Divide the input tensor into tiles of size 1x1xK.
+  // 2. Partially reduce each tile to a scalar using one thread.
+  // 3. Accumulate that scalar to the output vector using atomic operations.
+  //
+  // for (linear_index = threadIdx.x + blockIdx.x * blockDim.x;
+  //      linear_index < depth * height * width_in_tiles;
+  //      linear_index += blockDim.x * gridDim.x) {
+  //   int x_in_tiles = linear_index % width_in_tiles;
+  //   int y = linear_index / width_in_tiles % height;
+  //   int z = linear_index / (height * width_in_tiles);
+  //   float partial_result = 0;
+  //   for (element_id_in_tile : range(kTileSize)) {
+  //     int x = x_in_tiles * kTileSize + element_id_in_tile;
+  //     if (x < width)
+  //       partial_result = reducer(partial_result, input[z][y][x]);
+  //   }
+  //   AtomicReducer(&output[y], partial_result);
+  // }
+  //
+  // Three optimizations are performed.
+  //
+  // 1. To coalesce global memory accesses, dilate the tile with a factor of 32
+  // (i.e. the warp size). For example, suppose the width is 8x32=256. Instead
+  // of making each tile consecutive, we make tile 0 cover column
+  // [0,32,64,...,224], tile 1 column [1,33,65,...,225], and so on. This ensures
+  // that threads in a warp access consecutive memory in one iteration (i.e.
+  // coalesced). In the above example, the warp that contains thread 0-31
+  // accesses column 0-31 in the first iteration, and 32-63 in the second
+  // iteration, and so on.
+  //
+  // 2. Partially accumulate the partially reduced results computed by threads
+  // in the same warp using shfl_down. Using shfl_down is faster than directly
+  // using atomic operations because shfl_down transfers the data between
+  // threads using shared memory and threads in the same warp run in lock step
+  // (thus no extra synchronization needed). See
+  // https://devblogs.nvidia.com/parallelforall/faster-parallel-reductions-kepler/
+  // for details. The downside is, to produce correct results when using
+  // shfl_down, we need to guarantee threads in the same warp work on input
+  // elements with the same y, so the number of tiles in each row must be a
+  // multiple of 32.
+  //
+  // 3. Specialize the case that the entire tile is in bounds. When that is
+  // true, we don't need to emit "if (x < width)" inside the loop on
+  // element_id_in_tile. With these optimizations, the generated code is
+  // roughly:
+  //
+  // for (linear_index = threadIdx.x + blockIdx.x * blockDim.x;
+  //      linear_index < depth * height * width_in_tiles;
+  //      linear_index += blockDim.x * gridDim.x) {
+  //   int x_in_tiles = linear_index % width_in_tiles;
+  //   int y = linear_index / width_in_tiles % height;
+  //   int z = linear_index / (height * width_in_tiles);
+  //   int warp_id = x_in_tiles / warpSize;
+  //   int lane_id = x_in_tiles % warpSize;
+  //   float partial_result = init_value;
+  //   for (element_id_in_tile : range(kTileSize)) {
+  //     int x = lane_id + warpSize * (element_id_in_tile + warp_id * kTileSize);
+  //     if (x < width)
+  //       partial_result = Reducer(partial_result, input[z][y][x]);
+  //   }
+  //   for (shuffle_distance = 16; shuffle_distance > 0; shuffle_distance /= 2)
+  //     partial_result = Reducer(
+  //         partial_result,
+  //         __shfl_down(partial_result, shuffle_distance));
+  //   if (lane_id == 0)
+  //     AtomicReducer(&output[y], partial_result);
+  // }
+  //
+  // Choose 8 as the tile size, which matches Eigen's RowReduceKernel.
+  constexpr int64 kTileSize = 8;
+  // Round the width in tiles up to the nearest multiple of kWarpSize, so that
+  // the use of shfl_down is valid.
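+  // For example (illustrative numbers only): with width = 100 and
+  // kTileSize = 8, CeilOfRatio(100, 8) = 13 tiles per row, which is rounded up
+  // to kWarpSize (32) below so that every warp works on a single output row;
+  // elements past the width are simply skipped by the bounds check.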
+ const int64 width_in_tiles = + RoundUpToNearest(CeilOfRatio(width, kTileSize), kWarpSize); + + auto loop_body_emitter = + [=](const llvm_ir::IrArray::Index& tile_index) -> Status { + // Emit the loop body that reduces one tile. + llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType( + input_shape.element_type(), &ir_builder_); + llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca( + element_ir_type, /*ArraySize=*/nullptr, "partial_reduction_result"); + { + TF_ASSIGN_OR_RETURN(llvm::Value * init_ir_value, + init_value_gen(llvm_ir::IrArray::Index({}))); + ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address); + } + + // Emit an inner for-loop that partially reduces the elements in the given + // tile. + llvm::Value* z = tile_index[0]; + llvm::Value* y = tile_index[1]; + llvm::Value* x_tile = tile_index[2]; + llvm::Value* warp_id = ir_builder_.CreateUDiv( + x_tile, ir_builder_.getInt64(kWarpSize), "warp_id"); + llvm::Value* lane_id = ir_builder_.CreateURem( + x_tile, ir_builder_.getInt64(kWarpSize), "lane_id"); + + // The x-location of the last element in this tile. + // last_x = lane_id + warpSize * (kTileSize - 1 + warp_id * kTileSize); + llvm::Value* last_x = ir_builder_.CreateNSWAdd( + lane_id, + ir_builder_.CreateNSWMul( + ir_builder_.getInt64(kWarpSize), + ir_builder_.CreateNSWAdd( + ir_builder_.getInt64(kTileSize - 1), + ir_builder_.CreateNSWMul(warp_id, + ir_builder_.getInt64(kTileSize))))); + + auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status { + std::unique_ptr tile_element_loop = + llvm_ir::ForLoop::EmitForLoop("element_id_in_tile", + ir_builder_.getInt64(0), + ir_builder_.getInt64(kTileSize), + ir_builder_.getInt64(1), &ir_builder_); + + // Emit the body of the partial reduction loop. + llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(), + &ir_builder_); + // x = lane_id + warpSize * (element_id_in_tile + warp_id * kTileSize); + llvm::Value* x = ir_builder_.CreateNSWAdd( + lane_id, + ir_builder_.CreateNSWMul( + ir_builder_.getInt64(kWarpSize), + ir_builder_.CreateNSWAdd( + tile_element_loop->GetIndVarValue(), + ir_builder_.CreateNSWMul(warp_id, + ir_builder_.getInt64(kTileSize))))); + + // Unless we know the tile is entirely in bounds, we have to emit a + // x-in-bounds check before reading from the input. + if (!tile_in_bounds) { + llvm_ir::LlvmIfData if_x_in_bounds_data = llvm_ir::EmitIfThenElse( + ir_builder_.CreateICmpULT(x, ir_builder_.getInt64(width)), + "x_in_bounds", &ir_builder_); + + // Points ir_builder_ to the then-block. + llvm_ir::SetToFirstInsertPoint(if_x_in_bounds_data.true_block, + &ir_builder_); + } + + // Emit code that reads the input element and accumulates it to the + // partial reduction result. + llvm::Value* input_address = ir_builder_.CreateAlloca(element_ir_type); + { + // {z,y,x} is an index to input_3d_tensor_shape [depth,height,width]. We + // need to convert that to an index to input_shape (the shape of the + // operand of "reduce"). This conversion is composed of a transposition + // from input_shape to normalized_input_shape and a reshape from + // normalized_input_shape to input_3d_tensor_shape. 
+ const Shape normalized_input_shape = + ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(input_shape); + const std::vector transpose_dimension_mapping( + input_shape.layout().minor_to_major().rbegin(), + input_shape.layout().minor_to_major().rend()); + const Shape input_3d_tensor_shape = + ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( + input_shape.element_type(), {depth, height, width}); + const llvm_ir::IrArray::Index input_3d_tensor_index( + {z, y, x}, input_3d_tensor_shape, &ir_builder_); + const llvm_ir::IrArray::Index input_index = + input_3d_tensor_index + .SourceIndexOfReshape(input_3d_tensor_shape, + normalized_input_shape, &ir_builder_) + .SourceIndexOfTranspose(normalized_input_shape, input_shape, + transpose_dimension_mapping, + &ir_builder_); + TF_ASSIGN_OR_RETURN(llvm::Value * input_ir_value, + input_gen(input_index)); + ir_builder_.CreateStore(input_ir_value, input_address); + } + return EmitCallToNestedComputation( + *reducer, {partial_reduction_result_address, input_address}, + partial_reduction_result_address); + }; + + llvm::Value* tile_in_bounds = ir_builder_.CreateOr( + ir_builder_.getInt1(width % (kTileSize * kWarpSize) == 0), + ir_builder_.CreateICmpULT(last_x, ir_builder_.getInt64(width))); + llvm_ir::LlvmIfData if_tile_in_bounds_data = + llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &ir_builder_); + llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block, + &ir_builder_); + TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/true)); + llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.false_block, + &ir_builder_); + TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/false)); + + // After the if-then-else statement on tile_in_bounds, emit calls to + // shfl_down that accumulate the partial reduction results of all threads + // from the warp. + llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block, + &ir_builder_); + for (int shuffle_distance = 16; shuffle_distance >= 1; + shuffle_distance /= 2) { + llvm::Value* partial_reduction_result = ir_builder_.CreateLoad( + partial_reduction_result_address, "partial_reduction_result"); + llvm::Value* result_from_other_lane = ir_builder_.CreateAlloca( + element_ir_type, nullptr, "result_from_other_lane"); + ir_builder_.CreateStore( + EmitShuffleDown(partial_reduction_result, + ir_builder_.getInt32(shuffle_distance), &ir_builder_), + result_from_other_lane); + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + *reducer, {partial_reduction_result_address, result_from_other_lane}, + partial_reduction_result_address)); + } + + const HloInstruction* output = + reduce->IsFused() ? reduce->fusion_instruction() : reduce; + + // Emit an atomic operation that accumulates the partial reduction result of + // lane 0 (which holds the partially accumulated result for its warp) to the + // output element. + llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse( + ir_builder_.CreateICmpEQ(lane_id, ir_builder_.getInt64(0)), + "lane_id_is_zero", &ir_builder_); + llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, + &ir_builder_); + llvm::Value* output_address = GetIrArray(*output).EmitArrayElementAddress( + llvm_ir::IrArray::Index(y, output->shape(), &ir_builder_), &ir_builder_, + "output_element_address"); + return EmitAtomicOperationForNestedComputation( + *reducer, output_address, partial_reduction_result_address); + }; + + // Emit a parallel loop that iterates through every input tiles. 
+ Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout( + reduce->shape().element_type(), {depth, height, width_in_tiles}, + {2, 1, 0}); + LaunchDimensions launch_dimensions = CalculateLaunchDimensions( + tiled_input_shape, ir_emitter_context_->device_description()); + CHECK(LastThunk()->kind() == Thunk::Kind::kSequential); + UpdateLaunchDimensions( + launch_dimensions, + static_cast(LastThunk())->thunks().back().get(), + ir_emitter_context_->llvm_module()); + return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape, + launch_dimensions, &ir_builder_) + .EmitLoop(); +} + +// Figures out whether `reduce` is a row or column reduction, and which +// dimensions to reduce, and calls either `EmitRowReduction` or +// `EmitColumnReduction` as appropriate. +// Prerequisite: the shape of `reduce` has rank 1 and, if `reduce` is fused, the +// fused subgraph is pure elementwise. +Status IrEmitterUnnested::EmitReductionToVector( + HloInstruction* reduce, const Shape& input_shape, + const llvm_ir::ElementGenerator& input_gen, + const llvm_ir::ElementGenerator& init_value_gen, + tensorflow::gtl::ArraySlice dimensions_to_reduce, + HloComputation* reducer) { + // This emission requires "reduce" to have an input layout. It is either set + // by LayoutAssignment (for a top-level kReduce) or by InstructionFusion (for + // a fused kReduce). + CHECK(input_shape.has_layout()) << "LayoutAssignment or InstructionFusion " + "doesn't set the input layout of " + << reduce->ToString(); + + // Specialize multi-dimensional-array-to-vector reduction. + // + // TODO(b/33239522): we could use the same algorithm for general reduction + // as long as the input dimensions to keep are adjacent in the layout and + // have the same relative layout as their corresponding output dimensions. + // For example, reducing shape [2,3,4,5] with minor_to_major={2,0,1,3} to + // shape [2,4] with minor_to_major={1,0} can be implemented as a column + // reduction from shape [15,8] to shape [8]. + int64 input_dim_to_keep = -1; + for (int64 input_dim = 0; input_dim < ShapeUtil::Rank(input_shape); + ++input_dim) { + if (std::find(dimensions_to_reduce.begin(), dimensions_to_reduce.end(), + input_dim) == dimensions_to_reduce.end()) { + input_dim_to_keep = input_dim; + break; + } + } + CHECK_NE(-1, input_dim_to_keep); + + if (LayoutUtil::Minor(input_shape.layout(), 0) == input_dim_to_keep) { + // Column reduction. Treat the result of "input" as a matrix whose width + // is the most minor dimension and height the product of other dimensions, + // and treat "reduce" as a column reduction of the input matrix. + const int64 width = ShapeUtil::ElementsIn(reduce->shape()); + // "width" can be zero, so don't do + // height = ShapeUtil::ElementsIn(input_shape) / width; + int64 height = 1; + for (int64 input_dim = 0; input_dim < ShapeUtil::Rank(input_shape); + ++input_dim) { + if (input_dim != input_dim_to_keep) { + height *= input_shape.dimensions(input_dim); + } + } + return EmitColumnReduction(height, width, reduce, input_shape, input_gen, + init_value_gen, reducer); + } else { + // Reduce the row dimension of a matrix or reduce dimension 0 and 2 in a + // 3D tensor. The size of dimension 1 (the height) is the size of the + // dimension to keep, the size of dimension 0 (the depth) is the product + // of dimensions that are more major than the dimension to keep, and the + // size of dimension 2 (the width) is the product of more minor + // dimensions. 
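+    //
+    // For example, reducing dimensions {0,2,3} of a [2,3,4,5] shape with a
+    // row-major layout (minor_to_major {3,2,1,0}) keeps dimension 1, so the
+    // computation below yields depth = 2, height = 3 and width = 4 * 5 = 20.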
+    int64 depth = 1;
+    int64 width = 1;
+    for (int64 input_dim = 0; input_dim < ShapeUtil::Rank(input_shape);
+         ++input_dim) {
+      if (PositionInContainer(
+              AsInt64Slice(input_shape.layout().minor_to_major()), input_dim) >
+          PositionInContainer(
+              AsInt64Slice(input_shape.layout().minor_to_major()),
+              input_dim_to_keep)) {
+        depth *= input_shape.dimensions(input_dim);
+      } else if (PositionInContainer(
+                     AsInt64Slice(input_shape.layout().minor_to_major()),
+                     input_dim) <
+                 PositionInContainer(
+                     AsInt64Slice(input_shape.layout().minor_to_major()),
+                     input_dim_to_keep)) {
+        width *= input_shape.dimensions(input_dim);
+      }
+    }
+    int64 height = input_shape.dimensions(input_dim_to_keep);
+    return EmitRowReduction(depth, height, width, reduce, input_shape,
+                            input_gen, init_value_gen, reducer);
+  }
+}
+
+Status IrEmitterUnnested::HandleReduce(
+    HloInstruction* reduce, HloInstruction* input, HloInstruction* init_value,
+    tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
+    HloComputation* reducer) {
+  // HandleReduce specializes reduction from a multi-dimensional array to a 1D
+  // array. The specialized version requires an initializer thunk that
+  // initializes the output array to the initial value of the reduce.
+  if (IsReductionToVector(*reduce) &&
+      // NVPTX backend can't do atomic cmpxchg any narrower than 32 bits
+      32 <= primitive_util::BitWidth(reduce->shape().element_type())) {
+    std::vector<std::unique_ptr<Thunk>> thunks;
+    thunks.emplace_back(BuildKernelThunk(reduce));
+    TF_RETURN_IF_ERROR(EmitInitializer(
+        reduce, static_cast<KernelThunk*>(thunks.back().get())));
+    bindings_.UnbindAllLocalIrValues();
+    thunks.emplace_back(BuildKernelThunk(reduce));
+    thunk_sequence_->emplace_back(
+        MakeUnique<SequentialThunk>(std::move(thunks), reduce));
+    return EmitReductionToVector(
+        reduce, input->shape(),
+        [this, input](const llvm_ir::IrArray::Index& index) {
+          return GetIrArray(*input).EmitReadArrayElement(index, &ir_builder_);
+        },
+        [this, init_value](const llvm_ir::IrArray::Index& index) {
+          return GetIrArray(*init_value)
+              .EmitReadArrayElement(index, &ir_builder_);
+        },
+        dimensions_to_reduce, reducer);
+  }
+
+  thunk_sequence_->emplace_back(BuildKernelThunk(reduce));
+  return IrEmitter::HandleReduce(reduce, input, init_value,
+                                 dimensions_to_reduce, reducer);
+}
+
+Status IrEmitterUnnested::HandleTuple(
+    HloInstruction* tuple,
+    tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
+  bool all_tuple_elements_have_buffer = std::all_of(
+      operands.begin(), operands.end(), [this](HloInstruction* tuple_element) {
+        return ir_emitter_context_->buffer_assignment().HasTopLevelAllocation(
+            tuple_element);
+      });
+  // Tuples (especially output tuples) can have too many elements, causing the
+  // emitted kernel to exceed the parameter space limit (b/31336476). As an
+  // optimization, if all tuple elements have a buffer, we collect their
+  // buffer addresses in a host array, and then copy that array to the tuple's
+  // buffer.
+  //
+  // Some tuple elements (e.g. const or bitcast of const) might not have a
+  // buffer -- their contents are stored in code. In that case, we fall back
+  // to emitting kernels which have access to their buffer addresses in code.
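+  //
+  // For example, a three-element output tuple whose elements all have
+  // allocations is handled by a single TupleThunk that writes the three
+  // element addresses into the tuple buffer; if any element lacks an
+  // allocation, a kernel is emitted below instead.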
+ if (all_tuple_elements_have_buffer) { + std::vector tuple_element_buffers; + for (const HloInstruction* tuple_element : operands) { + tuple_element_buffers.push_back(GetAllocationIndex(*tuple_element)); + } + thunk_sequence_->emplace_back(MakeUnique( + tuple_element_buffers, GetAllocationIndex(*tuple), tuple)); + return Status::OK(); + } + // If `inst` is a nested thunk that can be disassembled from the result tuple, + // GpuExecutable will disassemble it and return it as part of the resultant + // ShapedBuffer. + if (has_hybrid_result_ && + ReachRootViaOnlyTuples(*tuple, *hlo_computation_->root_instruction())) { + return Status::OK(); + } + thunk_sequence_->emplace_back(BuildKernelThunk(tuple)); + return IrEmitter::HandleTuple(tuple, operands); +} + +Status IrEmitterUnnested::HandleGetTupleElement( + HloInstruction* get_tuple_element, HloInstruction* operand) { + // GetTupleElement IR is emitted in the IR context of the user instruction, + // and so we do not build a kernel for GetTupleElement instructions. + return Status::OK(); +} + +Status IrEmitterUnnested::HandleSelectAndScatter( + HloInstruction* select_and_scatter) { + CHECK_EQ(select_and_scatter->operand_count(), 3); + const auto* operand = select_and_scatter->operand(0); + const auto* source = select_and_scatter->operand(1); + const Window& window = select_and_scatter->window(); + PrimitiveType operand_element_type = operand->shape().element_type(); + const int64 rank = ShapeUtil::Rank(operand->shape()); + CHECK_EQ(rank, ShapeUtil::Rank(source->shape())); + CHECK_EQ(rank, window.dimensions_size()); + + { + std::vector> thunks; + thunks.emplace_back(BuildKernelThunk(select_and_scatter)); + TF_RETURN_IF_ERROR(EmitInitializer( + select_and_scatter, static_cast(thunks.back().get()))); + bindings_.UnbindAllLocalIrValues(); + thunks.emplace_back(BuildKernelThunk(select_and_scatter)); + thunk_sequence_->emplace_back( + MakeUnique(std::move(thunks), select_and_scatter)); + } + + // TODO(b/31410564): Implement dilation rate for select-and-scatter. + if (window_util::HasDilation(window)) { + return Unimplemented( + "Dilation for select-and-scatter not implemented on GPU. " + "See b/31410564."); + } + + // kSelectAndScatter is implemented as two kernel launches: the first launch + // initializes the output array to the given initial value, + // and the second accumulates the "source" matrix to the + // selected elements in the output array. The first launch is already + // implemented by the initializer thunk generated earlier, so this function + // only needs to take care of the select-and-scatter part. + // + // Pseudo code for select-and-scatter: + // + // for (coordinates S in the source): # This loop is parallel. + // initialized_flag = false + // for (coordinates W in the window): + // I = S * stride + W - pad_low + // if I within bounds of operand: + // if !(initialized_flag and select(selected_value, operand(I))): + // selected_value = operand(I) + // selected_index = I + // initialized_flag = true + // output(selected_index) = scatter(output(selected_index), source(S)) + auto loop_body_emitter = + [=](const llvm_ir::IrArray::Index& source_index) -> Status { + // Allocate space to keep the currently selected value, its index, and a + // boolean flag if the value is initialized. The initialized_flag is set + // false. 
+ llvm::Value* selected_value_address = llvm_ir::EmitAllocaAtFunctionEntry( + llvm_ir::PrimitiveTypeToIrType(operand_element_type, &ir_builder_), + "selected_value_address", &ir_builder_); + llvm::Value* selected_index_address = + llvm_ir::EmitAllocaAtFunctionEntryWithCount( + ir_builder_.getInt64Ty(), ir_builder_.getInt32(rank), + "selected_index_address", &ir_builder_); + llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry( + ir_builder_.getInt1Ty(), "initialized_flag_address", &ir_builder_); + ir_builder_.CreateStore(ir_builder_.getInt1(false), + initialized_flag_address); + + // Create the inner loop to iterate over the window. + llvm_ir::ForLoopNest window_loops(&ir_builder_); + std::vector window_size; + for (const auto& dim : window.dimensions()) { + window_size.push_back(dim.size()); + CHECK_GT(dim.size(), 0); + } + const llvm_ir::IrArray::Index window_index = window_loops.AddLoopsForShape( + ShapeUtil::MakeShape(operand_element_type, window_size), "window"); + llvm_ir::SetToFirstInsertPoint(window_loops.GetInnerLoopBodyBasicBlock(), + &ir_builder_); + + // Compute the operand index to visit and evaluate the condition whether the + // operand index is within the bounds. The unsigned comparison includes + // checking whether the operand index >= 0. + llvm_ir::IrArray::Index operand_index(source_index.size()); + llvm::Value* in_bounds_condition = ir_builder_.getInt1(true); + for (int64 i = 0; i < rank; ++i) { + llvm::Value* strided_index = ir_builder_.CreateNSWMul( + source_index[i], ir_builder_.getInt64(window.dimensions(i).stride())); + operand_index[i] = ir_builder_.CreateNSWSub( + ir_builder_.CreateNSWAdd(strided_index, window_index[i]), + ir_builder_.getInt64(window.dimensions(i).padding_low())); + llvm::Value* index_condition = ir_builder_.CreateICmpULT( + operand_index[i], + ir_builder_.getInt64(ShapeUtil::GetDimension(operand->shape(), i))); + in_bounds_condition = + ir_builder_.CreateAnd(in_bounds_condition, index_condition); + } + CHECK(in_bounds_condition != nullptr); + + // Only need to do something if the operand index is within the bounds. + // First check if the initialized_flag is set. + llvm_ir::LlvmIfData if_in_bounds = + llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &ir_builder_); + llvm_ir::SetToFirstInsertPoint(if_in_bounds.true_block, &ir_builder_); + llvm_ir::LlvmIfData if_initialized = llvm_ir::EmitIfThenElse( + ir_builder_.CreateLoad(initialized_flag_address), "initialized", + &ir_builder_); + + // If the initialized_flag is false, initialize the selected value and index + // with the currently visiting operand. + llvm_ir::SetToFirstInsertPoint(if_initialized.false_block, &ir_builder_); + const auto save_operand_index = [&]( + const llvm_ir::IrArray::Index& operand_index) { + for (int64 i = 0; i < rank; ++i) { + llvm::Value* selected_index_address_slot = + ir_builder_.CreateInBoundsGEP(selected_index_address, + {ir_builder_.getInt32(i)}); + ir_builder_.CreateStore(operand_index[i], selected_index_address_slot); + } + }; + llvm_ir::IrArray operand_array(GetIrArray(*operand)); + llvm::Value* operand_data = + operand_array.EmitReadArrayElement(operand_index, &ir_builder_); + ir_builder_.CreateStore(operand_data, selected_value_address); + save_operand_index(operand_index); + ir_builder_.CreateStore(ir_builder_.getInt1(true), + initialized_flag_address); + + // If the initialized_flag is true, call the `select` function to + // potentially update the selected value and index with the currently + // visiting operand. 
+ llvm_ir::SetToFirstInsertPoint(if_initialized.true_block, &ir_builder_); + const Shape output_shape = ShapeUtil::MakeShape(PRED, {}); + llvm::Value* operand_address = + operand_array.EmitArrayElementAddress(operand_index, &ir_builder_); + llvm::Value* select_return_buffer = llvm_ir::EmitAllocaAtFunctionEntry( + llvm_ir::PrimitiveTypeToIrType(PRED, &ir_builder_), + "select_return_buffer", &ir_builder_); + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + *select_and_scatter->select(), + {selected_value_address, operand_address}, select_return_buffer)); + llvm::Value* result = ir_builder_.CreateLoad(select_return_buffer); + + // If the 'select' function returns false, update the selected value and the + // index to the currently visiting operand. + llvm::Value* cond = ir_builder_.CreateICmpNE( + result, llvm::ConstantInt::get( + llvm_ir::PrimitiveTypeToIrType(PRED, &ir_builder_), 0), + "boolean_predicate"); + llvm_ir::LlvmIfData if_select_lhs = + llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &ir_builder_); + llvm_ir::SetToFirstInsertPoint(if_select_lhs.false_block, &ir_builder_); + ir_builder_.CreateStore(ir_builder_.CreateLoad(operand_address), + selected_value_address); + save_operand_index(operand_index); + + // After iterating over the window elements, scatter the source element to + // the selected index of the output. The value we store at the output + // location is computed by calling the `scatter` function with the source + // value and the current output value. + llvm_ir::SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(), + &ir_builder_); + llvm_ir::IrArray::Index selected_index; + for (int64 i = 0; i < rank; ++i) { + llvm::Value* selected_index_address_slot = ir_builder_.CreateInBoundsGEP( + selected_index_address, {ir_builder_.getInt32(i)}); + selected_index.push_back( + ir_builder_.CreateLoad(selected_index_address_slot)); + } + llvm::Value* source_value_address = + GetIrArray(*source).EmitArrayElementAddress(source_index, &ir_builder_); + llvm::Value* output_value_address = + GetIrArray(*select_and_scatter) + .EmitArrayElementAddress(selected_index, &ir_builder_); + return EmitAtomicOperationForNestedComputation( + *select_and_scatter->scatter(), output_value_address, + source_value_address); + }; + + LaunchDimensions launch_dimensions = CalculateLaunchDimensions( + source->shape(), ir_emitter_context_->device_description()); + UpdateLaunchDimensions( + launch_dimensions, + // IrEmitterUnnested implements kSelectAndScatter as a SequentialThunk + // consisting of two thunks, an initializer KernelThunk that initializes + // the output and another KernelThunk that accumulates the scattered + // elements. + static_cast(LastThunk())->thunks().back().get(), + ir_emitter_context_->llvm_module()); + return ParallelLoopEmitter(loop_body_emitter, source->shape(), + launch_dimensions, &ir_builder_) + .EmitLoop(); +} + +Status IrEmitterUnnested::HandleWhile(HloInstruction* xla_while, + HloInstruction* init, + HloComputation* condition, + HloComputation* body) { + TF_RET_CHECK(ShapeUtil::IsScalar(condition->root_instruction()->shape()) && + condition->root_instruction()->shape().element_type() == PRED) + << "While condition computation must return bool"; + // Build ForThunk for conformant while loops, otherwise build WhileThunk. 
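+  //
+  // For example, a while loop that counts from start 0 to limit 10 in steps
+  // of 3 is conformant, and the trip count computed below is
+  // (10 - 0 + 3 - 1) / 3 = 4.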
+ auto result = CanTransformWhileToFor(xla_while); + if (result.ok()) { + auto tuple = result.ConsumeValueOrDie(); + // loop_trip_count = (limit - start + increment - 1) / increment + const int64 loop_trip_count = + (std::get<1>(tuple) - std::get<0>(tuple) + std::get<2>(tuple) - 1) / + std::get<2>(tuple); + thunk_sequence_->emplace_back(BuildForThunk(xla_while, loop_trip_count)); + VLOG(3) << "Built ForThunk for while: " << xla_while->name(); + } else { + thunk_sequence_->emplace_back(BuildWhileThunk(xla_while)); + VLOG(3) << "Built WhileThunk for while: " << xla_while->name() + << " while-to-for transform status: " << result.status(); + } + return Status::OK(); +} + +Status IrEmitterUnnested::HandleRng(HloInstruction* random, + RandomDistribution distribution) { + thunk_sequence_->push_back(BuildKernelThunk(random)); + return IrEmitter::HandleRng(random, distribution); +} + +Status IrEmitterUnnested::HandleSelect(HloInstruction* select, + HloInstruction* pred, + HloInstruction* on_true, + HloInstruction* on_false) { + thunk_sequence_->push_back(BuildKernelThunk(select)); + return IrEmitter::HandleSelect(select, pred, on_true, on_false); +} + +llvm::Function* IrEmitterUnnested::EmitBasePointersForHloAndItsOperands( + const HloInstruction& hlo, std::vector* io_hlos) { + const BufferAssignment& buffer_assignment = + ir_emitter_context_->buffer_assignment(); + // GetTupleElement instructions are implemented by emitting IR that indexes + // and loads the target tuple element pointer from its operand (possibly + // recursively). For this reason, GetTupleElement instructions are associated + // with their operand buffer in 'io_hlos' and 'non_io_hlos' below. + std::vector non_io_hlos; + for (const HloInstruction* operand : hlo.operands()) { + const HloInstruction* to_lookup = LatestNonGteAncestor(operand); + if (buffer_assignment.HasTopLevelAllocation(to_lookup) && + buffer_assignment.GetUniqueTopLevelAllocation(to_lookup) + .ConsumeValueOrDie() + ->IsInputOrOutput()) { + io_hlos->push_back(operand); + } else { + non_io_hlos.push_back(operand); + } + } + + CHECK_NE(HloOpcode::kGetTupleElement, hlo.opcode()); + if (buffer_assignment.HasTopLevelAllocation(&hlo) && + buffer_assignment.GetUniqueTopLevelAllocation(&hlo) + .ConsumeValueOrDie() + ->IsInputOrOutput()) { + io_hlos->push_back(&hlo); + } else { + non_io_hlos.push_back(&hlo); + } + + llvm::Function* kernel = BuildKernelPrototype(hlo, *io_hlos); + // bindings_ is reused because the bindings of kConstant to their underlying + // llvm::Constant can be shared for all HLOs in this computation. + bindings_.EmitBasePointersForHlos(*io_hlos, non_io_hlos); + return kernel; +} + +std::unique_ptr IrEmitterUnnested::BuildKernelThunk( + const HloInstruction* inst) { + std::vector io_hlos; + llvm::Function* kernel = + EmitBasePointersForHloAndItsOperands(*inst, &io_hlos); + + // Compute the input buffer indices. + std::vector io_buffers; + for (const HloInstruction* io_hlo : io_hlos) { + io_buffers.push_back(GetAllocationIndex(*LatestNonGteAncestor(io_hlo))); + } + + // Create a KernelThunk that launches the kernel that implements "inst". 
+  return MakeUnique<KernelThunk>(io_buffers,
+                                 llvm_ir::AsString(kernel->getName()), inst);
+}
+
+std::unique_ptr<Thunk> IrEmitterUnnested::BuildCopyThunk(
+    const HloInstruction* inst) {
+  const HloInstruction* operand = inst->operand(0);
+  CHECK_EQ(HloOpcode::kConstant, operand->opcode());
+  return MakeUnique<CopyThunk>(
+      /*source_address=*/LiteralUtil::InternalData(operand->literal()),
+      /*destination_buffer=*/GetAllocationIndex(*inst),
+      /*mem_size=*/llvm_ir::ByteSizeOf(
+          operand->shape(),
+          ir_emitter_context_->llvm_module()->getDataLayout()),
+      inst);
+}
+
+std::unique_ptr<Thunk> IrEmitterUnnested::BuildGemmThunk(
+    const HloInstruction* inst) {
+  if (inst->opcode() == HloOpcode::kDot) {
+    const HloInstruction* lhs = inst->operand(0);
+    const HloInstruction* rhs = inst->operand(1);
+    return MakeUnique<GemmThunk>(
+        GetAllocationIndex(*lhs),   // The buffer assigned to LHS.
+        GetAllocationIndex(*rhs),   // The buffer assigned to RHS.
+        GetAllocationIndex(*inst),  // The output buffer.
+        lhs->shape(),               // The shape of LHS.
+        rhs->shape(),               // The shape of RHS.
+        inst->shape(),              // The shape of the output.
+        false,                      // Do not transpose LHS.
+        false,                      // Do not transpose RHS.
+        inst);
+  }
+
+  if (inst->opcode() == HloOpcode::kFusion) {
+    const HloInstruction* dot = inst->fused_expression_root();
+    DCHECK(dot->opcode() == HloOpcode::kDot);
+    const HloInstruction* lhs_parameter = StripTranspose(*dot->operand(0));
+    const HloInstruction* rhs_parameter = StripTranspose(*dot->operand(1));
+    DCHECK(lhs_parameter->opcode() == HloOpcode::kParameter &&
+           rhs_parameter->opcode() == HloOpcode::kParameter);
+    const HloInstruction* lhs =
+        inst->operand(lhs_parameter->parameter_number());
+    const HloInstruction* rhs =
+        inst->operand(rhs_parameter->parameter_number());
+
+    return MakeUnique<GemmThunk>(
+        GetAllocationIndex(*lhs),   // The buffer assigned to LHS.
+        GetAllocationIndex(*rhs),   // The buffer assigned to RHS.
+        GetAllocationIndex(*inst),  // The output buffer.
+        lhs->shape(),               // The shape of LHS.
+        rhs->shape(),               // The shape of RHS.
+        inst->shape(),              // The shape of the output.
+        dot->operand(0)->IsRank2Transpose(),  // Transpose LHS.
+        dot->operand(1)->IsRank2Transpose(),  // Transpose RHS.
+        inst);
+  }
+
+  LOG(FATAL) << "Cannot build a GemmThunk for " << inst->ToString();
+}
+
+std::unique_ptr<Thunk> IrEmitterUnnested::BuildConvolutionThunk(
+    const HloInstruction* inst) {
+  const HloInstruction* lhs = inst->operand(0);
+  const HloInstruction* rhs = inst->operand(1);
+  if (inst->opcode() == HloOpcode::kConvolution) {
+    // Forward convolution.
+    return MakeUnique<ConvolutionThunk>(
+        ConvolutionThunk::ConvolutionKind::kForward,
+        /*input_buffer=*/GetAllocationIndex(*lhs),
+        /*filter_buffer=*/GetAllocationIndex(*rhs),
+        /*output_buffer=*/GetAllocationIndex(*inst),
+        /*input_shape=*/lhs->shape(),
+        /*filter_shape=*/rhs->shape(),
+        /*output_shape=*/inst->shape(), inst->window(),
+        inst->convolution_dimension_numbers(), inst);
+  }
+
+  // Backward filter convolution, which takes the input (activations) and the
+  // gradients, and computes the filter.
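+  //
+  // Note how the buffer roles are permuted below: for a backward-filter
+  // fusion the thunk's "filter" is the fusion's result and its "output" is
+  // the gradients operand, while for a backward-input fusion the thunk's
+  // "input" is the fusion's result.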
+ CHECK_EQ(HloOpcode::kFusion, inst->opcode()); + switch (inst->fusion_kind()) { + case HloInstruction::FusionKind::kConvBackwardFilter: + return MakeUnique( + ConvolutionThunk::ConvolutionKind::kBackwardFilter, + /*input_buffer=*/GetAllocationIndex(*lhs), + /*filter_buffer=*/GetAllocationIndex(*inst), + /*output_buffer=*/GetAllocationIndex(*rhs), + /*input_shape=*/lhs->shape(), + /*filter_shape=*/inst->shape(), + /*output_shape=*/rhs->shape(), inst->window(), + inst->convolution_dimension_numbers(), inst); + case HloInstruction::FusionKind::kConvBackwardInput: + return MakeUnique( + ConvolutionThunk::ConvolutionKind::kBackwardInput, + /*input_buffer=*/GetAllocationIndex(*inst), + /*filter_buffer=*/GetAllocationIndex(*rhs), + /*output_buffer=*/GetAllocationIndex(*lhs), + /*input_shape=*/inst->shape(), + /*filter_shape=*/rhs->shape(), + /*output_shape=*/lhs->shape(), inst->window(), + inst->convolution_dimension_numbers(), inst); + default: + LOG(FATAL) << "Not a convolution-fusion"; + } +} + +Status IrEmitterUnnested::EmitInitializer(const HloInstruction* hlo, + KernelThunk* thunk) { + bool fused = HloOpcode::kFusion == hlo->opcode(); + + const HloInstruction* inst = fused ? hlo->fused_expression_root() : hlo; + CHECK(inst->opcode() == HloOpcode::kSelectAndScatter || + inst->opcode() == HloOpcode::kReduce); + const HloInstruction* init_value = nullptr; + switch (inst->opcode()) { + case HloOpcode::kSelectAndScatter: + init_value = inst->operand(2); + break; + case HloOpcode::kReduce: + init_value = inst->operand(1); + break; + default: + LOG(FATAL) << "Opcode " << inst->opcode() + << " should not need an initializer."; + } + + if (fused && init_value->opcode() == HloOpcode::kParameter) { + init_value = hlo->operand(init_value->parameter_number()); + } + + return EmitTargetElementLoopInThunk( + *hlo, + [=](const llvm_ir::IrArray::Index& index) { + return GetIrArray(*init_value) + .EmitReadArrayElement(index, &ir_builder_); + }, + thunk); +} + +namespace { + +// Checks that all buffers used during while loop iteration share the same +// buffer allocation. This includes buffers for while result, while init +// operand, condition parameter, body parameter and body result. +// Returns OK on success, error status otherwise. 
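+//
+// For example, for a while instruction whose shape is the tuple
+// (s32[], f32[1024]), the allocation assigned at ShapeIndex {1} must be the
+// same for the while result, the init operand, the condition parameter, the
+// body parameter and the body root.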
+Status CheckWhileBuffersShareAllocation(
+    const HloInstruction* xla_while,
+    const BufferAssignment& buffer_assignment) {
+  return ShapeUtil::ForEachSubshape(
+      xla_while->shape(),
+      [&buffer_assignment, &xla_while](const Shape& /*subshape*/,
+                                       const ShapeIndex& index) -> Status {
+        auto check = [&buffer_assignment](const HloInstruction* a,
+                                          const HloInstruction* b,
+                                          const ShapeIndex& index) -> Status {
+          BufferAllocation::Index index_a =
+              buffer_assignment.GetUniqueAllocation(a, index)
+                  .ConsumeValueOrDie()
+                  ->index();
+          BufferAllocation::Index index_b =
+              buffer_assignment.GetUniqueAllocation(b, index)
+                  .ConsumeValueOrDie()
+                  ->index();
+          if (index_a != index_b) {
+            return InternalError(
+                "instruction %s does not share allocation with "
+                "instruction %s ",
+                a->ToString().c_str(), b->ToString().c_str());
+          }
+          return Status::OK();
+        };
+        const HloInstruction* condition_parameter =
+            xla_while->while_condition()->parameter_instruction(0);
+        const HloComputation* body = xla_while->while_body();
+        const HloInstruction* body_parameter = body->parameter_instruction(0);
+        const HloInstruction* body_result = body->root_instruction();
+        TF_RETURN_IF_ERROR(check(xla_while, xla_while->operand(0), index));
+        TF_RETURN_IF_ERROR(check(xla_while, condition_parameter, index));
+        TF_RETURN_IF_ERROR(check(xla_while, body_parameter, index));
+        TF_RETURN_IF_ERROR(check(xla_while, body_result, index));
+        return Status::OK();
+      });
+}
+
+}  // namespace
+
+std::unique_ptr<Thunk> IrEmitterUnnested::BuildWhileThunk(
+    const HloInstruction* hlo) {
+  // Check that all while-related buffers share an allocation.
+  TF_CHECK_OK(CheckWhileBuffersShareAllocation(
+      hlo, ir_emitter_context_->buffer_assignment()));
+
+  // Generate thunk sequence for while 'condition'.
+  HloComputation* condition = hlo->while_condition();
+  IrEmitterUnnested ir_emitter_condition(hlo_module_config_, condition,
+                                         /*has_hybrid_result=*/false,
+                                         ir_emitter_context_);
+  TF_CHECK_OK(condition->root_instruction()->Accept(&ir_emitter_condition));
+
+  // Generate thunk sequence for while 'body'.
+  HloComputation* body = hlo->while_body();
+  IrEmitterUnnested ir_emitter_body(hlo_module_config_, body,
+                                    false /* has_hybrid_result */,
+                                    ir_emitter_context_);
+  TF_CHECK_OK(body->root_instruction()->Accept(&ir_emitter_body));
+
+  return MakeUnique<WhileThunk>(
+      GetAllocationIndex(*condition->root_instruction()),  // cond result
+      ir_emitter_condition.ConsumeThunkSequence(),
+      ir_emitter_body.ConsumeThunkSequence(), hlo);
+}
+
+std::unique_ptr<Thunk> IrEmitterUnnested::BuildForThunk(
+    const HloInstruction* hlo, const int64 loop_limit) {
+  // Check that all while-related buffers share an allocation.
+  TF_CHECK_OK(CheckWhileBuffersShareAllocation(
+      hlo, ir_emitter_context_->buffer_assignment()));
+
+  // Generate thunk sequence for while 'body' (used as the For loop body).
+ HloComputation* body = hlo->while_body(); + IrEmitterUnnested ir_emitter_body(hlo_module_config_, body, + false /* has_hybrid_result */, + ir_emitter_context_); + TF_CHECK_OK(body->root_instruction()->Accept(&ir_emitter_body)); + + return MakeUnique(loop_limit, + ir_emitter_body.ConsumeThunkSequence(), hlo); +} + +Status IrEmitterUnnested::EmitTargetElementLoopInThunk( + const HloInstruction& hlo, + const llvm_ir::ElementGenerator& element_generator, KernelThunk* thunk) { + LaunchDimensions launch_dimensions = CalculateLaunchDimensions( + hlo.shape(), ir_emitter_context_->device_description()); + UpdateLaunchDimensions(launch_dimensions, thunk, + ir_emitter_context_->llvm_module()); + // Otherwise, emit a parallel loop that computes the partition that each + // thread is in charge of. + return ParallelLoopEmitter(element_generator, GetIrArray(hlo), + launch_dimensions, &ir_builder_) + .EmitLoop(); +} + +Status IrEmitterUnnested::EmitTargetElementLoop( + const HloInstruction& hlo, + const llvm_ir::ElementGenerator& element_generator) { + CHECK(Thunk::Kind::kKernel == LastThunk()->kind()); + return EmitTargetElementLoopInThunk(hlo, element_generator, + static_cast(LastThunk())); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc new file mode 100644 index 0000000000..14760fe92c --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc @@ -0,0 +1,94 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h" + +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace se = ::perftools::gputools; + +namespace xla { +namespace gpu { + +using Index = BufferAllocation::Index; + +KernelThunk::KernelThunk(tensorflow::gtl::ArraySlice io_buffers, + const string& kernel_name, + const HloInstruction* hlo_instruction) + : Thunk(Kind::kKernel, hlo_instruction), + io_buffers_(io_buffers.begin(), io_buffers.end()), + kernel_name_(kernel_name) {} + +tensorflow::Status KernelThunk::Initialize(const GpuExecutable& executable) { + tensorflow::mutex_lock lock(mutex_); + if (loader_spec_) { + // Already initialized by another thread. + return tensorflow::Status::OK(); + } + + loader_spec_.reset(new se::MultiKernelLoaderSpec(io_buffers_.size() + 1)); + tensorflow::StringPiece ptx = executable.ptx(); + // Convert tensorflow::StringPiece to se::port::StringPiece because + // StreamExecutor uses the latter. 
+ loader_spec_->AddCudaPtxInMemory( + se::port::StringPiece(ptx.data(), ptx.size()), kernel_name_); + return tensorflow::Status::OK(); +} + +void KernelThunk::SetLaunchDimensions(const LaunchDimensions& launch_dims) { + tensorflow::mutex_lock lock(mutex_); + launch_dimensions_ = launch_dims; +} + +tensorflow::Status KernelThunk::ExecuteOnStream( + const BufferAllocations& buffer_allocations, se::Stream* stream) { + // Load the kernel. + se::StreamExecutor* executor = stream->parent(); + se::KernelBase kernel(executor); + LaunchDimensions launch_dimensions; + { + tensorflow::mutex_lock lock(mutex_); + if (!executor->GetKernel(*loader_spec_, &kernel)) { + return InternalError("Unable to load kernel %s", kernel_name_.c_str()); + } + launch_dimensions = launch_dimensions_; + } + + // Launch the kernel with potentially multiple blocks and threads. + static constexpr int kKernelArgsLimit = 1024; + auto kernel_args = MakeUnique>(); + for (BufferAllocation::Index io_buffer : io_buffers_) { + kernel_args->add_device_memory_argument( + buffer_allocations.GetDeviceAddress(io_buffer)); + } + kernel_args->add_device_memory_argument( + buffer_allocations.GetTempBufferBase()); + if (!stream->parent()->Launch( + stream, se::ThreadDim(launch_dimensions.threads_per_block()), + se::BlockDim(launch_dimensions.block_count()), kernel, + *kernel_args)) { + return InternalError("Unable to launch kernel %s", kernel_name_.c_str()); + } + return tensorflow::Status::OK(); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h new file mode 100644 index 0000000000..901825873a --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h @@ -0,0 +1,86 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_KERNEL_THUNK_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_KERNEL_THUNK_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/thunk.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/thread_annotations.h" + +namespace xla { +namespace gpu { + +class GpuExecutable; + +// This class stores everything that StreamExecutor needs for launching a +// kernel. It implements the ExecuteOnStream interface for GpuExecutable to +// invoke the corresponding kernel. +// +// This is thread-compatible. 
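+//
+// Sketch of a typical lifetime (the kernel name and buffer indices are
+// illustrative):
+//
+//   auto thunk = MakeUnique<KernelThunk>(/*io_buffers=*/{0, 1},
+//                                        /*kernel_name=*/"fusion_kernel",
+//                                        /*hlo_instruction=*/fusion);
+//   thunk->SetLaunchDimensions(launch_dimensions);
+//   TF_RETURN_IF_ERROR(thunk->Initialize(gpu_executable));
+//   TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(buffer_allocations, stream));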
+class KernelThunk : public Thunk { + public: + // Constructs a thunk for the given kernel. + // + // `hlo_instruction` is as in Thunk. Other arguments are as the class members. + KernelThunk(tensorflow::gtl::ArraySlice io_buffers, + const string& kernel_name, const HloInstruction* hlo_instruction); + KernelThunk(const KernelThunk&) = delete; + KernelThunk& operator=(const KernelThunk&) = delete; + ~KernelThunk() override = default; + + const string& kernel_name() const { return kernel_name_; } + void SetLaunchDimensions(const LaunchDimensions& launch_dims); + + tensorflow::Status Initialize(const GpuExecutable& executable) override; + + // Executes the kernel for the thunk on "stream", which must be non-null. + tensorflow::Status ExecuteOnStream( + const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) override; + + private: + // The indices of the input/output buffers. + const std::vector io_buffers_; + + // Entry kernel name for the computation. + const string kernel_name_; + + // The thread and block dimension used to launch the kernel. + // Will be set by IrEmitterUnnested. + LaunchDimensions launch_dimensions_; + + // Describes how to load this kernel. ExecuteOnStream reuses this loader + // specification for all executions. + mutable tensorflow::mutex mutex_; + std::unique_ptr loader_spec_ + GUARDED_BY(mutex_); +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_KERNEL_THUNK_H_ diff --git a/tensorflow/compiler/xla/service/gpu/layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/layout_assignment.cc new file mode 100644 index 0000000000..ff6cfd9448 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/layout_assignment.cc @@ -0,0 +1,142 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/layout_assignment.h" + +#include + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace xla { +namespace gpu { + +Status GpuLayoutAssignment::AddBackendConstraints( + LayoutConstraints* constraints) { + for (auto& instruction : constraints->computation()->instructions()) { + // cuDNN is called with specific layouts on the input, output, and filter: + // + // input: DataLayout::kBatchDepthYX + // output: DataLayout::kBatchDepthYX + // filter: FilterLayout::kOutputInputYX + // + // The order dimensions in the constant name is major-to-minor (eg, the + // most-major dimension of the input is batch, most-minor is X). 
The + // specific dimension numbers these named dimensions correspond to is + // determined by the ConvolutionDimensionNumbers argument. Y is spatial + // dimension 0, and X is spatial dimension 1. + // + // TODO(b/29399649): Be more flexible about handling layouts of cuDNN calls. + if (ImplementedAsDnnConvolution(*instruction)) { + HloInstruction* input = nullptr; + HloInstruction* filter = nullptr; + HloInstruction* output = nullptr; + if (instruction->opcode() == HloOpcode::kConvolution) { + input = instruction->mutable_operand(0); + filter = instruction->mutable_operand(1); + output = instruction.get(); + } else { + CHECK_EQ(HloOpcode::kFusion, instruction->opcode()); + switch (instruction->fusion_kind()) { + case HloInstruction::FusionKind::kConvBackwardFilter: + // filter = BackwardFilterConvolve(input, output) + input = instruction->mutable_operand(0); + filter = instruction.get(); + output = instruction->mutable_operand(1); + break; + case HloInstruction::FusionKind::kConvBackwardInput: + // input = BackwardInputConvolve(output, filter) + input = instruction.get(); + filter = instruction->mutable_operand(1); + output = instruction->mutable_operand(0); + break; + default: + LOG(FATAL) << "Not a convolution-fusion"; + } + } + + // Construct minor-to-major dimension orders for operands and result. + // cuDNN's convolution APIs support the BDYX layout for activations/output + // and the OIYX layout for weights. + // TODO(b/29399649): Be more flexible about handling layouts of cuDNN + // calls after we switch to cuDNN v5. + const ConvolutionDimensionNumbers& dimension_numbers = + instruction->convolution_dimension_numbers(); + Shape input_shape(input->shape()); + *input_shape.mutable_layout() = + LayoutUtil::MakeLayout({dimension_numbers.spatial_dimensions(1), + dimension_numbers.spatial_dimensions(0), + dimension_numbers.feature_dimension(), + dimension_numbers.batch_dimension()}); + + Shape filter_shape(filter->shape()); + *filter_shape.mutable_layout() = LayoutUtil::MakeLayout( + {dimension_numbers.kernel_spatial_dimensions(1), + dimension_numbers.kernel_spatial_dimensions(0), + dimension_numbers.kernel_input_feature_dimension(), + dimension_numbers.kernel_output_feature_dimension()}); + + Shape output_shape(output->shape()); + *output_shape.mutable_layout() = + LayoutUtil::MakeLayout({dimension_numbers.spatial_dimensions(1), + dimension_numbers.spatial_dimensions(0), + dimension_numbers.feature_dimension(), + dimension_numbers.batch_dimension()}); + + // Set layouts of the instructions' shapes. 
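+      //
+      // (For a plain NCHW convolution, i.e. batch dimension 0, feature
+      // dimension 1 and spatial dimensions {2, 3}, each of the minor-to-major
+      // lists constructed above is {3, 2, 1, 0}.)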
+ if (instruction->opcode() == HloOpcode::kConvolution) { + TF_RETURN_IF_ERROR( + constraints->SetOperandLayout(input_shape, output, 0)); + TF_RETURN_IF_ERROR( + constraints->SetOperandLayout(filter_shape, output, 1)); + TF_RETURN_IF_ERROR( + constraints->SetInstructionLayout(output_shape, output)); + } else { + CHECK_EQ(HloOpcode::kFusion, instruction->opcode()); + switch (instruction->fusion_kind()) { + case HloInstruction::FusionKind::kConvBackwardFilter: + // filter = BackwardFilterConvolve(input, output) + TF_RETURN_IF_ERROR( + constraints->SetOperandLayout(input_shape, filter, 0)); + TF_RETURN_IF_ERROR( + constraints->SetInstructionLayout(filter_shape, filter)); + TF_RETURN_IF_ERROR( + constraints->SetOperandLayout(output_shape, filter, 1)); + break; + case HloInstruction::FusionKind::kConvBackwardInput: + // input = BackwardInputConvolve(output, filter) + TF_RETURN_IF_ERROR( + constraints->SetInstructionLayout(input_shape, input)); + TF_RETURN_IF_ERROR( + constraints->SetOperandLayout(output_shape, input, 0)); + TF_RETURN_IF_ERROR( + constraints->SetOperandLayout(filter_shape, input, 1)); + break; + default: + LOG(FATAL) << "Not a convolution-fusion"; + } + } + } + } + return Status::OK(); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/layout_assignment.h b/tensorflow/compiler/xla/service/gpu/layout_assignment.h new file mode 100644 index 0000000000..169041eb85 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/layout_assignment.h @@ -0,0 +1,41 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LAYOUT_ASSIGNMENT_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LAYOUT_ASSIGNMENT_H_ + +#include "tensorflow/compiler/xla/service/computation_layout.h" +#include "tensorflow/compiler/xla/service/layout_assignment.h" +#include "tensorflow/core/lib/core/status.h" + +namespace xla { +namespace gpu { + +// GPU-specific layout assignment pass which preassigns layouts to satisfy +// layout constraints for operands and results of library calls. +class GpuLayoutAssignment : public LayoutAssignment { + public: + explicit GpuLayoutAssignment(ComputationLayout* entry_computation_layout) + : LayoutAssignment(entry_computation_layout) {} + ~GpuLayoutAssignment() override {} + + protected: + Status AddBackendConstraints(LayoutConstraints* constraints) override; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LAYOUT_ASSIGNMENT_H_ diff --git a/tensorflow/compiler/xla/service/gpu/layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/layout_assignment_test.cc new file mode 100644 index 0000000000..692ec8147d --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/layout_assignment_test.cc @@ -0,0 +1,85 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/layout_assignment.h" + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/computation_layout.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_layout.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace gpu { +namespace { + +using LayoutAssignmentTest = HloTestBase; + +TEST_F(LayoutAssignmentTest, Elementwise) { + Shape ashape = ShapeUtil::MakeShape(F32, {42, 12}); + Shape ashape_in_row_major(ashape); + Shape ashape_in_col_major(ashape); + *ashape_in_row_major.mutable_layout() = LayoutUtil::MakeLayout({1, 0}); + *ashape_in_col_major.mutable_layout() = LayoutUtil::MakeLayout({0, 1}); + + // Enumerate all possible combinations of layouts. + for (const Shape& lhs_shape_with_layout : + {ashape_in_row_major, ashape_in_col_major}) { + for (const Shape& rhs_shape_with_layout : + {ashape_in_row_major, ashape_in_col_major}) { + for (const Shape& result_shape_with_layout : + {ashape_in_row_major, ashape_in_col_major}) { + // GpuLayoutAssignment should assign the same layout to "add" and its + // two operands. 
+ auto builder = HloComputation::Builder(TestName()); + auto x = builder.AddInstruction( + HloInstruction::CreateParameter(0, ashape, "x")); + auto y = builder.AddInstruction( + HloInstruction::CreateParameter(1, ashape, "y")); + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, x, y)); + HloModule module(TestName()); + HloComputation* computation = + module.AddEntryComputation(builder.Build(add)); + + ComputationLayout computation_layout( + computation->ComputeProgramShape()); + *computation_layout.mutable_parameter_layout(0) = + ShapeLayout(lhs_shape_with_layout); + *computation_layout.mutable_parameter_layout(1) = + ShapeLayout(rhs_shape_with_layout); + *computation_layout.mutable_result_layout() = + ShapeLayout(result_shape_with_layout); + + GpuLayoutAssignment layout_assignment(&computation_layout); + EXPECT_TRUE(layout_assignment.Run(&module).ValueOrDie()); + + for (const HloInstruction* operand : add->operands()) { + EXPECT_TRUE(LayoutUtil::Equal(add->shape().layout(), + operand->shape().layout())); + } + } + } + } +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD new file mode 100644 index 0000000000..fc0970049a --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD @@ -0,0 +1,88 @@ +licenses(["notice"]) # Apache 2.0 + +package( + default_visibility = [":friends"], + features = [ + "-parse_headers", + "no_layering_check", + ], +) + +package_group( + name = "friends", + includes = [ + "//tensorflow/compiler/xla:friends", + ], +) + +cc_library( + name = "llvm_gpu_backend", + srcs = [ + "dump_ir_pass.cc", + "gpu_backend_lib.cc", + "utils.cc", + ], + hdrs = [ + "dump_ir_pass.h", + "gpu_backend_lib.h", + "utils.h", + ], + deps = [ + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/legacy_flags:gpu_backend_lib_flags", + "//tensorflow/compiler/xla/legacy_flags:llvm_backend_flags", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "@llvm//:analysis", + "@llvm//:asm_printer", + "@llvm//:bit_reader", + "@llvm//:bit_writer", + "@llvm//:code_gen", + "@llvm//:core", + "@llvm//:instrumentation", + "@llvm//:ipo", + "@llvm//:ir_reader", + "@llvm//:linker", + "@llvm//:mc", + "@llvm//:nvptx_code_gen", + "@llvm//:objc_arc", + "@llvm//:support", + "@llvm//:target", + "@llvm//:transform_utils", + ], +) + +cc_test( + name = "utils_test", + size = "small", + srcs = ["utils_test.cc"], + data = [ + "tests_data/saxpy.ll", + ], + deps = [ + ":llvm_gpu_backend", + "//tensorflow/compiler/xla:types", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "@llvm//:core", + "@llvm//:support", + ], +) + +# ----------------------------------------------------------------------------- + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.cc new file mode 100644 index 0000000000..aeec3a03ca --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.cc @@ -0,0 +1,103 @@ +/* Copyright 2017 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.h" + +#include "external/llvm/include/llvm/IR/Module.h" +#include "external/llvm/include/llvm/Support/FileSystem.h" +#include "external/llvm/include/llvm/Support/raw_ostream.h" +#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { +namespace gpu { + +// Pass which dumps the IR of a module into a file. +// +// Because it is implemented as a FunctionPass (IR is dumped +// function-by-function) rather than as a ModulePass the resulting IR is not +// valid (missing metadata, for example) but is still useful for inspection. +// The pass needs to be a FunctionPass rather than a ModulePass because +// inserting ModulePasses is disruptive to LLVM's pass manager. For sequential +// FunctionPasses (also SCC passes, etc) the pass manager executes the passes +// sequentially on each function (SCC, etc). Inserting a ModulePass between +// FunctionPasses acts as a barrier forcing the FunctionPasses to execute fully +// across all functions prior to advancing to the next pass. For some reason +// this results in different generated code resulting in an undesirable +// Heisenberg effect when dumping the IR. +class DumpIrPass : public llvm::FunctionPass { + public: + explicit DumpIrPass(const string &output_filename) + : llvm::FunctionPass(id_), output_filename_(output_filename) {} + + bool doInitialization(llvm::Module &M) override { + out_.reset(new llvm::raw_fd_ostream(llvm::StringRef(output_filename_), ec_, + llvm::sys::fs::F_None)); + if (ec_) { + LOG(FATAL) << "Unable to open " << output_filename_ + << " to dump LLVM IR: " << ec_.message(); + } + return false; + } + + bool runOnFunction(llvm::Function &Function) override { + Function.print(*out_); + return false; + } + + void getAnalysisUsage(llvm::AnalysisUsage &AU) const override { + AU.setPreservesAll(); + } + + bool doFinalization(llvm::Module &M) override { + out_->close(); + return false; + } + + private: + static char id_; + string output_filename_; + std::error_code ec_; + std::unique_ptr out_; +}; + +char DumpIrPass::id_ = 0; + +void IrDumpingPassManager::run(llvm::Module &module) { + for (int i = 0; i < passes_.size(); ++i) { + llvm::Pass *P = passes_[i]; + if (dump_ir_) { + const llvm::PassInfo *PI = + llvm::PassRegistry::getPassRegistry()->getPassInfo(P->getPassID()); + const string basename = ReplaceFilenameExtension( + tensorflow::io::Basename(input_filename_), + tensorflow::strings::Printf( + "pass-%02d.before.%s.ll", i, + (PI == nullptr ? 
"unknown" : PI->getPassArgument().data()))); + llvm::legacy::PassManager::add( + new DumpIrPass(tensorflow::io::JoinPath(output_dir_, basename))); + } + llvm::legacy::PassManager::add(P); + } + passes_.clear(); + llvm::legacy::PassManager::run(module); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.h b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.h new file mode 100644 index 0000000000..1d515a0f28 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.h @@ -0,0 +1,51 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LLVM_GPU_BACKEND_DUMP_IR_PASS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LLVM_GPU_BACKEND_DUMP_IR_PASS_H_ + +#include + +#include "external/llvm/include/llvm/IR/LegacyPassManager.h" +#include "external/llvm/include/llvm/Pass.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { +namespace gpu { + +// Pass manager which optionally dumps the IR to a sequence of files before each +// pass. +class IrDumpingPassManager : public llvm::legacy::PassManager { + public: + IrDumpingPassManager(const string& input_filename, const string& output_dir, + bool dump_ir) + : llvm::legacy::PassManager(), + input_filename_(input_filename), + output_dir_(output_dir), + dump_ir_(dump_ir) {} + void add(llvm::Pass* P) { passes_.push_back(P); } + void run(llvm::Module& module); // NOLINT(runtime/references) + + private: + string input_filename_; + string output_dir_; + bool dump_ir_; + std::vector passes_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LLVM_GPU_BACKEND_DUMP_IR_PASS_H_ diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc new file mode 100644 index 0000000000..e15659938c --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc @@ -0,0 +1,489 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.h" +#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/util.h" + +#include "external/llvm/include/llvm/ADT/STLExtras.h" +#include "external/llvm/include/llvm/ADT/StringMap.h" +#include "external/llvm/include/llvm/ADT/StringSet.h" +#include "external/llvm/include/llvm/Analysis/TargetLibraryInfo.h" +#include "external/llvm/include/llvm/Analysis/TargetTransformInfo.h" +#include "external/llvm/include/llvm/Bitcode/BitcodeReader.h" +#include "external/llvm/include/llvm/Bitcode/BitcodeWriter.h" +#include "external/llvm/include/llvm/CodeGen/CommandFlags.h" +#include "external/llvm/include/llvm/IR/LLVMContext.h" +#include "external/llvm/include/llvm/IR/LegacyPassManager.h" +#include "external/llvm/include/llvm/IR/Module.h" +#include "external/llvm/include/llvm/LinkAllIR.h" +#include "external/llvm/include/llvm/LinkAllPasses.h" +#include "external/llvm/include/llvm/Linker/Linker.h" +#include "external/llvm/include/llvm/PassRegistry.h" +#include "external/llvm/include/llvm/Support/CommandLine.h" +#include "external/llvm/include/llvm/Support/FileSystem.h" +#include "external/llvm/include/llvm/Support/FormattedStream.h" +#include "external/llvm/include/llvm/Support/TargetRegistry.h" +#include "external/llvm/include/llvm/Support/TargetSelect.h" +#include "external/llvm/include/llvm/Support/ToolOutputFile.h" +#include "external/llvm/include/llvm/Target/TargetMachine.h" +#include "external/llvm/include/llvm/Transforms/IPO.h" +#include "external/llvm/include/llvm/Transforms/IPO/AlwaysInliner.h" +#include "external/llvm/include/llvm/Transforms/IPO/PassManagerBuilder.h" + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { +namespace gpu { +namespace { + +// Default inline threshold value to use in llvm. +const int kDefaultInlineThreshold = 1100; + +// Information about a GPU architecture for the backend. +struct GpuBackendInfo { + string libdevice_name; + string sm_name; +}; + +// Maps supported CUDA compute capability to a libdevice file to link for this +// capability. +std::map gpu_info_map = { + {"compute_20", {"libdevice.compute_20.10.bc", "sm_20"}}, + {"compute_30", {"libdevice.compute_30.10.bc", "sm_30"}}, + {"compute_35", {"libdevice.compute_35.10.bc", "sm_35"}}, + + // NVIDIA does not provide a separate libdevice for CC 3.7, but we can use + // the one for 3.5. + {"compute_37", {"libdevice.compute_35.10.bc", "sm_37"}}, +}; + +// Validate the --gpu_architecture command-line flag. +static void ValidateGPUArchitecture(const string& value) { + if (!gpu_info_map.count(value)) { + LOG(FATAL) << "value for --gpu_architecture must be compute_{20,30,35,37}"; + } +} + +// Convenience function for producing a name of a temporary compilation product +// from the input filename. 
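+//
+// For example (paths are illustrative), with dump_temp_products_to set to
+// "/tmp/xla", an input filename of "foo.ll" and an extension of "ptx" would
+// map to "/tmp/xla/foo.ptx", assuming ReplaceFilenameExtension swaps the
+// trailing extension.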
+string MakeNameForTempProduct(const std::string& input_filename, + tensorflow::StringPiece extension) { + legacy_flags::GpuBackendLibFlags* flags = + legacy_flags::GetGpuBackendLibFlags(); + return tensorflow::io::JoinPath( + flags->dump_temp_products_to, + ReplaceFilenameExtension( + tensorflow::io::Basename(llvm_ir::AsString(input_filename)), + extension)); +} + +// Initializes LLVM passes. Uses the PassRegistry mechanism. +void InitializePasses(llvm::PassRegistry* pass_registry) { + llvm::initializeCore(*pass_registry); + llvm::initializeCodeGen(*pass_registry); + llvm::initializeScalarOpts(*pass_registry); + llvm::initializeObjCARCOpts(*pass_registry); + llvm::initializeVectorization(*pass_registry); + llvm::initializeIPO(*pass_registry); + llvm::initializeAnalysis(*pass_registry); + llvm::initializeTransformUtils(*pass_registry); + llvm::initializeInstCombine(*pass_registry); + llvm::initializeInstrumentation(*pass_registry); + llvm::initializeTarget(*pass_registry); + llvm::initializeCodeGenPreparePass(*pass_registry); +} + +// Returns the TargetMachine, given a triple. +std::unique_ptr GetTargetMachine( + llvm::Triple triple, tensorflow::StringPiece cpu_name) { + std::string error; + const llvm::Target* target = TargetRegistry::lookupTarget("", triple, error); + if (target == nullptr) { + LOG(FATAL) << "Unable to find Target for triple '" << triple.str() << "'" + << " -- " << error; + return nullptr; + } + + TargetOptions target_options = InitTargetOptionsFromCodeGenFlags(); + // Enable FMA synthesis if desired. + legacy_flags::GpuBackendLibFlags* flags = + legacy_flags::GetGpuBackendLibFlags(); + if (flags->fma) { + target_options.AllowFPOpFusion = FPOpFusion::Fast; + } + + // Set options from LlvmBackendFlags (specifically, fast-math flags). + llvm_ir::SetTargetOptions(&target_options); + + // Set the verbose assembly options. + target_options.MCOptions.AsmVerbose = flags->verbose_ptx_asm; + + // The selection of codegen optimization level is copied from function + // GetCodeGenOptLevel in //external/llvm/tools/opt/opt.cpp. + CodeGenOpt::Level codegen_opt_level; + switch (flags->opt_level) { + case 1: + codegen_opt_level = CodeGenOpt::Less; + break; + case 2: + codegen_opt_level = CodeGenOpt::Default; + break; + case 3: + codegen_opt_level = CodeGenOpt::Aggressive; + break; + default: + codegen_opt_level = CodeGenOpt::None; + } + return WrapUnique(target->createTargetMachine( + triple.str(), llvm_ir::AsStringRef(cpu_name), "+ptx42", target_options, + Optional(RelocModel), CMModel, codegen_opt_level)); +} + +// Adds the standard LLVM optimization passes, based on the speed optimization +// level (opt_level) and size optimization level (size_level). Both module +// and function-level passes are added, so two pass managers are passed in and +// modified by this function. +void AddOptimizationPasses(unsigned opt_level, unsigned size_level, + llvm::TargetMachine* target_machine, + llvm::legacy::PassManagerBase* module_passes, + llvm::legacy::FunctionPassManager* function_passes) { + PassManagerBuilder builder; + builder.OptLevel = opt_level; + builder.SizeLevel = size_level; + + if (opt_level > 1) { + builder.Inliner = llvm::createFunctionInliningPass(kDefaultInlineThreshold); + } else { + // Only inline functions marked with "alwaysinline". 
+ builder.Inliner = llvm::createAlwaysInlinerLegacyPass(); + } + + builder.DisableUnitAtATime = false; + builder.DisableUnrollLoops = opt_level == 0; + builder.LoopVectorize = opt_level > 0; + builder.SLPVectorize = opt_level > 1 && size_level < 2; + + // NVPTX's early-as-possible passes include NVVM reflect. + builder.addExtension( + llvm::PassManagerBuilder::EP_EarlyAsPossible, + [&](const PassManagerBuilder&, legacy::PassManagerBase& pass_manager) { + target_machine->addEarlyAsPossiblePasses(pass_manager); + }); + + builder.populateFunctionPassManager(*function_passes); + builder.populateModulePassManager(*module_passes); +} + +// Emits the given module to a bit code file. +void EmitBitcodeToFile(const Module& module, tensorflow::StringPiece filename) { + std::error_code error_code; + llvm::tool_output_file outfile(filename.ToString().c_str(), error_code, + llvm::sys::fs::F_None); + if (error_code) { + LOG(FATAL) << "opening bitcode file for writing: " << error_code.message(); + } + + llvm::WriteBitcodeToFile(&module, outfile.os()); + outfile.keep(); +} + +// Emits the given module to PTX. target_machine is an initialized TargetMachine +// for the NVPTX target. +string EmitModuleToPTX(Module* module, llvm::TargetMachine* target_machine) { + std::string ptx; // need a std::string instead of a ::string. + { + llvm::raw_string_ostream stream(ptx); + llvm::buffer_ostream pstream(stream); + // The extension is stripped by IrDumpingPassManager, so we need to + // get creative to add a suffix. + string module_id(llvm_ir::AsString(module->getModuleIdentifier())); + legacy_flags::GpuBackendLibFlags* flags = + legacy_flags::GetGpuBackendLibFlags(); + IrDumpingPassManager codegen_passes( + ReplaceFilenameExtension(tensorflow::io::Basename(module_id), + "-nvptx.dummy"), + flags->dump_temp_products_to, flags->dump_ir_before_passes); + codegen_passes.add(new llvm::TargetLibraryInfoWrapperPass( + llvm::Triple(module->getTargetTriple()))); + + target_machine->addPassesToEmitFile(codegen_passes, pstream, + llvm::TargetMachine::CGFT_AssemblyFile); + codegen_passes.run(*module); + } + + return ptx; +} + +// LLVM has an extensive flags mechanism of its own, which is only accessible +// through the command line. Internal libraries within LLVM register parsers for +// flags, with no other way to configure them except pass these flags. +// To do this programmatically, we invoke ParseCommandLineOptions manually with +// a "fake argv". +// Note: setting flags with this method is stateful, since flags are just +// static globals within LLVM libraries. +void FeedLLVMWithFlags(const std::vector& cl_opts) { + std::vector fake_argv = {""}; + for (const string& cl_opt : cl_opts) { + fake_argv.push_back(cl_opt.c_str()); + } + llvm::cl::ParseCommandLineOptions(fake_argv.size(), &fake_argv[0]); +} + +namespace { +// Returns whether the module could use any libdevice functions. This function +// may have false positives -- the module might not use libdevice even if this +// function returns true. +bool CouldNeedLibdevice(const llvm::Module& module) { + for (const llvm::Function& function : module.functions()) { + // This is a conservative approximation -- not all such functions are in + // libdevice. + if (!function.isIntrinsic() && function.isDeclaration()) { + return true; + } + } + return false; +} + +// Links libdevice into the given module if the module needs libdevice. 
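The "needs libdevice" heuristic above is deliberately coarse: any declared-but-undefined non-intrinsic function counts as a potential libdevice dependency. A minimal sketch that exercises the same condition on a hand-written module, assuming LLVM's textual IR parser is reachable under the same external/llvm include prefix used elsewhere in this file (the function name and module text are illustrative only):

#include <memory>

#include "external/llvm/include/llvm/AsmParser/Parser.h"
#include "external/llvm/include/llvm/IR/LLVMContext.h"
#include "external/llvm/include/llvm/IR/Module.h"
#include "external/llvm/include/llvm/Support/SourceMgr.h"

// Illustrative only: a module that declares (but does not define) a
// non-intrinsic function such as __nv_expf trips the heuristic, because the
// definition can only come from libdevice or some other linked library.
bool ExampleModuleCouldNeedLibdevice() {
  llvm::LLVMContext context;
  llvm::SMDiagnostic err;
  std::unique_ptr<llvm::Module> module = llvm::parseAssemblyString(
      "declare float @__nv_expf(float)\n", err, context);
  if (module == nullptr) {
    return false;  // Parse error; nothing to inspect.
  }
  for (const llvm::Function& function : module->functions()) {
    if (!function.isIntrinsic() && function.isDeclaration()) {
      return true;  // Same condition as the heuristic above.
    }
  }
  return false;
}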
+tensorflow::Status LinkLibdeviceIfNecessary(const string& libdevice_dir_path, + llvm::Module* module) { + if (!CouldNeedLibdevice(*module)) { + return tensorflow::Status::OK(); + } + + llvm::Linker linker(*module); + legacy_flags::GpuBackendLibFlags* flags = + legacy_flags::GetGpuBackendLibFlags(); + ValidateGPUArchitecture(flags->gpu_architecture); + string libdevice_bc_filename = + gpu_info_map[flags->gpu_architecture].libdevice_name; + string libdevice_bc_fullpath = + tensorflow::io::JoinPath(libdevice_dir_path, libdevice_bc_filename); + TF_RETURN_IF_ERROR( + tensorflow::Env::Default()->FileExists(libdevice_bc_fullpath)); + std::unique_ptr libdevice_module = + LoadIRModule(libdevice_bc_fullpath, &module->getContext()); + VLOG(1) << "Linking with libdevice from: " << libdevice_bc_fullpath; + if (linker.linkInModule(std::move(libdevice_module), + llvm::Linker::Flags::InternalizeLinkedSymbols | + llvm::Linker::Flags::LinkOnlyNeeded)) { + LOG(FATAL) << "Error linking libdevice from " << libdevice_bc_fullpath; + } + return tensorflow::Status::OK(); +} + +} // namespace + +StatusOr CompileModuleToPtx(llvm::Module* module, + const string& libdevice_dir_path) { + // Link the input module with libdevice, to pull in implementations of some + // builtins. + TF_RETURN_IF_ERROR(LinkLibdeviceIfNecessary(libdevice_dir_path, module)); + + legacy_flags::GpuBackendLibFlags* flags = + legacy_flags::GetGpuBackendLibFlags(); + if (!flags->dump_temp_products_to.empty()) { + string linked_filename = + MakeNameForTempProduct(module->getModuleIdentifier(), "linked.bc"); + LOG(INFO) << "dumping bitcode after linking libdevice to: " + << linked_filename; + EmitBitcodeToFile(*module, linked_filename); + } + + // Set the flush-denormals-to-zero flag on the module so the NVVM reflect pass + // can access it. + module->addModuleFlag(llvm::Module::Override, "nvvm-reflect-ftz", flags->ftz); + + // If ftz is enabled, set it as an attribute on every function in the module. + if (flags->ftz) { + for (llvm::Function& fn : *module) { + fn.addFnAttr("nvptx-f32ftz", "true"); + } + } + + // Run IR-level optimizations. + if (flags->dump_ir_before_passes && flags->dump_temp_products_to.empty()) { + LOG(FATAL) << "--dump_ir_before_passes must be specified with " + "--dump_temp_products_to"; + } + + IrDumpingPassManager module_passes(module->getModuleIdentifier(), + flags->dump_temp_products_to, + flags->dump_ir_before_passes); + + // Add an appropriate TargetLibraryInfo pass for the module's triple. + llvm::TargetLibraryInfoWrapperPass* tliwp = + new llvm::TargetLibraryInfoWrapperPass( + llvm::Triple(module->getTargetTriple())); + module_passes.add(tliwp); + + // Try to fetch the target triple from the module. If not present, set a + // default target triple. + llvm::Triple target_triple = llvm::Triple(module->getTargetTriple()); + if (target_triple.getArch() == llvm::Triple::UnknownArch) { + LOG(WARNING) << "target triple not found in the module"; + target_triple = llvm::Triple("nvptx64-unknown-unknown"); + } + + // Figure out the exact name of the processor as known to the NVPTX backend + // from the gpu_architecture flag. + ValidateGPUArchitecture(flags->gpu_architecture); + string cpu_name = gpu_info_map[flags->gpu_architecture].sm_name; + + std::unique_ptr target_machine = + GetTargetMachine(target_triple, cpu_name); + module_passes.add(llvm::createTargetTransformInfoWrapperPass( + target_machine->getTargetIRAnalysis())); + + // The LLVM IR verifier performs sanity checking on the IR. 
This helps
+  // discover problems and report them in a meaningful manner, rather than
+  // letting later passes report obscure assertions because of unfulfilled
+  // invariants.
+  module_passes.add(llvm::createVerifierPass());
+
+  // Create the function-level pass manager. It needs data layout information
+  // too.
+  llvm::legacy::FunctionPassManager function_passes(module);
+
+  AddOptimizationPasses(flags->opt_level, /*size_level=*/0,
+                        target_machine.get(), &module_passes, &function_passes);
+  // Loop unrolling exposes more opportunities for SROA. Therefore, we run SROA
+  // again after the standard optimization passes [http://b/13329423].
+  // TODO(jingyue): SROA may further expose more optimization opportunities,
+  // such as more precise alias analysis and more function inlining (SROA may
+  // change the inlining cost of a function). For now, running SROA already
+  // emits good enough code for the evaluated benchmarks. We may want to run
+  // more optimizations later.
+  if (flags->opt_level > 0) {
+    // LLVM's optimizer turns on SROA when the optimization level is greater
+    // than 0. We mimic this behavior here.
+    module_passes.add(llvm::createSROAPass());
+  }
+
+  // Verify that the module is well formed after optimizations ran.
+  module_passes.add(llvm::createVerifierPass());
+
+  // Done populating the pass managers. Now run them.
+
+  function_passes.doInitialization();
+  for (auto func = module->begin(); func != module->end(); ++func) {
+    function_passes.run(*func);
+  }
+  function_passes.doFinalization();
+  module_passes.run(*module);
+
+  if (!flags->dump_temp_products_to.empty()) {
+    string optimized_filename =
+        MakeNameForTempProduct(module->getModuleIdentifier(), "optimized.bc");
+    LOG(INFO) << "dumping bitcode after optimizations to: "
+              << optimized_filename;
+    EmitBitcodeToFile(*module, optimized_filename);
+  }
+
+  // Finally, produce PTX.
+  return EmitModuleToPTX(module, target_machine.get());
+}
+
+// One-time module initializer.
+// Must be called only once -- DO NOT CALL DIRECTLY.
+void GPUBackendInit() {
+  // Feed all customized flags here, so we can override them with llvm_cl_opts
+  // without redeploying the compiler, for development purposes.
+
+  // This flag tunes a threshold in branch folding. The default threshold,
+  // which is one, is not suitable for CUDA programs where branches are more
+  // expensive than for CPU programs. Setting the threshold to 2 improves the
+  // latency of TwoDPatchDotProductKernel_IND_3_ND_48 by over 5%, and does not
+  // affect the latency of other benchmarks so far.
+  //
+  // I also tried setting this threshold to other values:
+  // * 3-6 give results similar to 2;
+  // * >6 starts hurting the performance of at least dot product kernels.
+  //
+  // TODO(jingyue): The current threshold only considers the number of IR
+  // instructions, which does not accurately reflect the true cost. We need a
+  // better cost model.
+  FeedLLVMWithFlags({"-bonus-inst-threshold=2"});
+  // TODO(b/22073864): Increase the limit when scanning memory dependencies.
+  // This helps to eliminate more redundant load instructions.
+  //
+  // The specific value is currently large enough for s3d in the SHOC
+  // benchmark, which contains a lot of load instructions and many arithmetic
+  // instructions between those loads.
+ FeedLLVMWithFlags({"-memdep-block-scan-limit=500"}); + + legacy_flags::GpuBackendLibFlags* flags = + legacy_flags::GetGpuBackendLibFlags(); + if (!flags->llvm_cl_opts.empty()) { + std::vector opts = + tensorflow::str_util::Split(flags->llvm_cl_opts, ','); + FeedLLVMWithFlags(opts); + } + + if (flags->llvm_dump_passes) { + // Enable LLVM pass debugging dump. LLVM dumps this information when a pass + // manager is initialized for execution. It's done to stderr (this is + // hardcoded within LLVM to the dbgs() stream, we can't change it from the + // outside). + FeedLLVMWithFlags({"-debug-pass=Arguments"}); + } + + // Initialize the NVPTX target; it's the only target we link with, so call its + // specific initialization functions instead of the catch-all InitializeAll*. + LLVMInitializeNVPTXTarget(); + LLVMInitializeNVPTXTargetInfo(); + LLVMInitializeNVPTXTargetMC(); + LLVMInitializeNVPTXAsmPrinter(); + + // Initialize the LLVM optimization passes. + llvm::PassRegistry* registry = llvm::PassRegistry::getPassRegistry(); + InitializePasses(registry); +} + +} // namespace + +StatusOr CompileToPtx(llvm::Module* module, + const string& libdevice_dir_path) { + static std::once_flag backend_init_flag; + std::call_once(backend_init_flag, GPUBackendInit); + + string ptx; + { + ScopedLoggingTimer compilation_timer( + "Compile module " + llvm_ir::AsString(module->getName()), + /*vlog_level=*/2); + TF_ASSIGN_OR_RETURN(ptx, CompileModuleToPtx(module, libdevice_dir_path)); + } + return ptx; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h new file mode 100644 index 0000000000..3413f33301 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h @@ -0,0 +1,43 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// LLVM-based compiler backend. +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LLVM_GPU_BACKEND_GPU_BACKEND_LIB_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LLVM_GPU_BACKEND_GPU_BACKEND_LIB_H_ + +#include + +#include "external/llvm/include/llvm/IR/Module.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/stringpiece.h" + +namespace xla { +namespace gpu { + +// The Compile.* interfaces each create their own llvm::LLVMContext objects for +// thread safety, but note that LLVM's multithreaded support is very +// preliminary; multithreaded use is not recommended at this time. +// +// Compiles the argument module and returns it. libdevice_dir_path is the parent +// directory of the libdevice bitcode libraries. The contents of the module may +// be changed. 
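A minimal usage sketch for this entry point, assuming the StatusOr<string> return type used by its definition in gpu_backend_lib.cc above; the libdevice path below is a placeholder for the CUDA toolkit's nvvm/libdevice directory:

namespace xla {
namespace gpu {

// Illustrative caller: lower an existing llvm::Module to PTX text.
StatusOr<string> CompilePtxForExample(llvm::Module* module) {
  // Placeholder path; in practice this is the CUDA toolkit's
  // nvvm/libdevice directory.
  const string libdevice_dir = "/usr/local/cuda/nvvm/libdevice";
  return CompileToPtx(module, libdevice_dir);
}

}  // namespace gpu
}  // namespace xla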
+StatusOr CompileToPtx(llvm::Module* module, + const string& libdevice_dir_path); + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LLVM_GPU_BACKEND_GPU_BACKEND_LIB_H_ diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/tests_data/saxpy.ll b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/tests_data/saxpy.ll new file mode 100644 index 0000000000..2ae1d2f7ea --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/tests_data/saxpy.ll @@ -0,0 +1,141 @@ +target datalayout = "e-i64:64-v16:16-v32:32-n16:32:64" +target triple = "nvptx64-unknown-unknown" + +%struct.uint3 = type { i32, i32, i32 } +%struct.dim3 = type { i32, i32, i32 } + +@blockIdx = external addrspace(1) global %struct.uint3 +@blockDim = external addrspace(1) global %struct.dim3 +@threadIdx = external addrspace(1) global %struct.uint3 + +; Function Attrs: alwaysinline nounwind readnone +define float @expf(float %f) #0 { +entry: + %f.addr = alloca float, align 4 + store float %f, float* %f.addr, align 4 + %0 = load float, float* %f.addr, align 4 + %call = call float @__nv_expf(float %0) + ret float %call +} + +declare float @__nv_expf(float) #1 + +; Function Attrs: nounwind +define void @cuda_saxpy(i32* %n, float* %a, float* %x, float* %y) #2 { +entry: + %n.addr = alloca i32*, align 8 + %a.addr = alloca float*, align 8 + %x.addr = alloca float*, align 8 + %y.addr = alloca float*, align 8 + %i = alloca i32, align 4 + store i32* %n, i32** %n.addr, align 8 + store float* %a, float** %a.addr, align 8 + store float* %x, float** %x.addr, align 8 + store float* %y, float** %y.addr, align 8 + %0 = load i32, i32* getelementptr inbounds (%struct.uint3, %struct.uint3* addrspacecast (%struct.uint3 addrspace(1)* @blockIdx to %struct.uint3*), i32 0, i32 0), align 4 + %1 = load i32, i32* getelementptr inbounds (%struct.dim3, %struct.dim3* addrspacecast (%struct.dim3 addrspace(1)* @blockDim to %struct.dim3*), i32 0, i32 0), align 4 + %mul = mul i32 %0, %1 + %2 = load i32, i32* getelementptr inbounds (%struct.uint3, %struct.uint3* addrspacecast (%struct.uint3 addrspace(1)* @threadIdx to %struct.uint3*), i32 0, i32 0), align 4 + %add = add i32 %mul, %2 + store i32 %add, i32* %i, align 4 + %3 = load i32, i32* %i, align 4 + %4 = load i32*, i32** %n.addr, align 8 + %arrayidx = getelementptr inbounds i32, i32* %4, i64 0 + %5 = load i32, i32* %arrayidx, align 4 + %cmp = icmp slt i32 %3, %5 + br i1 %cmp, label %if.then, label %if.end + +if.then: ; preds = %entry + %6 = load float*, float** %a.addr, align 8 + %arrayidx1 = getelementptr inbounds float, float* %6, i64 0 + %7 = load float, float* %arrayidx1, align 4 + %8 = load i32, i32* %i, align 4 + %idxprom = sext i32 %8 to i64 + %9 = load float*, float** %x.addr, align 8 + %arrayidx2 = getelementptr inbounds float, float* %9, i64 %idxprom + %10 = load float, float* %arrayidx2, align 4 + %mul3 = fmul float %7, %10 + %11 = load i32, i32* %i, align 4 + %idxprom4 = sext i32 %11 to i64 + %12 = load float*, float** %y.addr, align 8 + %arrayidx5 = getelementptr inbounds float, float* %12, i64 %idxprom4 + %13 = load float, float* %arrayidx5, align 4 + %add6 = fadd float %mul3, %13 + %14 = load i32, i32* %i, align 4 + %idxprom7 = sext i32 %14 to i64 + %15 = load float*, float** %y.addr, align 8 + %arrayidx8 = getelementptr inbounds float, float* %15, i64 %idxprom7 + store float %add6, float* %arrayidx8, align 4 + br label %if.end + +if.end: ; preds = %if.then, %entry + ret void +} + +; Function Attrs: nounwind +define void 
@cuda_saxpy_s(i32* %n, float* %a, float* %x, float* %y) #2 { +entry: + %n.addr = alloca i32*, align 8 + %a.addr = alloca float*, align 8 + %x.addr = alloca float*, align 8 + %y.addr = alloca float*, align 8 + %i = alloca i32, align 4 + store i32* %n, i32** %n.addr, align 8 + store float* %a, float** %a.addr, align 8 + store float* %x, float** %x.addr, align 8 + store float* %y, float** %y.addr, align 8 + %0 = load i32, i32* getelementptr inbounds (%struct.uint3, %struct.uint3* addrspacecast (%struct.uint3 addrspace(1)* @blockIdx to %struct.uint3*), i32 0, i32 0), align 4 + %1 = load i32, i32* getelementptr inbounds (%struct.dim3, %struct.dim3* addrspacecast (%struct.dim3 addrspace(1)* @blockDim to %struct.dim3*), i32 0, i32 0), align 4 + %mul = mul i32 %0, %1 + %2 = load i32, i32* getelementptr inbounds (%struct.uint3, %struct.uint3* addrspacecast (%struct.uint3 addrspace(1)* @threadIdx to %struct.uint3*), i32 0, i32 0), align 4 + %add = add i32 %mul, %2 + store i32 %add, i32* %i, align 4 + call void @llvm.cuda.syncthreads() + %3 = load i32, i32* %i, align 4 + %4 = load i32*, i32** %n.addr, align 8 + %arrayidx = getelementptr inbounds i32, i32* %4, i64 0 + %5 = load i32, i32* %arrayidx, align 4 + %cmp = icmp slt i32 %3, %5 + br i1 %cmp, label %if.then, label %if.end + +if.then: ; preds = %entry + %6 = load float*, float** %a.addr, align 8 + %arrayidx1 = getelementptr inbounds float, float* %6, i64 0 + %7 = load float, float* %arrayidx1, align 4 + %8 = load i32, i32* %i, align 4 + %idxprom = sext i32 %8 to i64 + %9 = load float*, float** %x.addr, align 8 + %arrayidx2 = getelementptr inbounds float, float* %9, i64 %idxprom + %10 = load float, float* %arrayidx2, align 4 + %mul3 = fmul float %7, %10 + %11 = load i32, i32* %i, align 4 + %idxprom4 = sext i32 %11 to i64 + %12 = load float*, float** %y.addr, align 8 + %arrayidx5 = getelementptr inbounds float, float* %12, i64 %idxprom4 + %13 = load float, float* %arrayidx5, align 4 + %add6 = fadd float %mul3, %13 + %14 = load i32, i32* %i, align 4 + %idxprom7 = sext i32 %14 to i64 + %15 = load float*, float** %y.addr, align 8 + %arrayidx8 = getelementptr inbounds float, float* %15, i64 %idxprom7 + store float %add6, float* %arrayidx8, align 4 + br label %if.end + +if.end: ; preds = %if.then, %entry + ret void +} + +; Function Attrs: nounwind +declare void @llvm.cuda.syncthreads() #3 + +attributes #0 = { alwaysinline nounwind readnone "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-realign-stack" "stack-protector-buffer-size"="8" "unsafe-fp-math"="false" "use-soft-float"="false" } +attributes #1 = { "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-realign-stack" "stack-protector-buffer-size"="8" "unsafe-fp-math"="false" "use-soft-float"="false" } +attributes #2 = { nounwind "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-realign-stack" "stack-protector-buffer-size"="8" "unsafe-fp-math"="false" "use-soft-float"="false" } +attributes #3 = { nounwind } + +!nvvm.annotations = !{!0, !1} +!llvm.ident = !{!2} + +!0 = !{void (i32*, float*, float*, float*)* @cuda_saxpy, !"kernel", i32 1} +!1 = !{void (i32*, float*, float*, float*)* @cuda_saxpy_s, !"kernel", i32 1} +!2 = !{!"clang version xla-trunk (trunk r203011)"} diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.cc new 
file mode 100644 index 0000000000..c10346bbc2 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.cc @@ -0,0 +1,65 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h" + +#include "tensorflow/core/platform/logging.h" + +#include "external/llvm/include/llvm/IR/LLVMContext.h" +#include "external/llvm/include/llvm/IR/Module.h" +#include "external/llvm/include/llvm/IRReader/IRReader.h" +#include "external/llvm/include/llvm/Support/SourceMgr.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace { + +static void DieWithSMDiagnosticError(llvm::SMDiagnostic* diagnostic) { + LOG(FATAL) << diagnostic->getLineNo() << ":" << diagnostic->getColumnNo() + << ": " << diagnostic->getMessage().str(); +} + +} // namespace + +namespace xla { +namespace gpu { + +std::unique_ptr LoadIRModule(const string& filename, + llvm::LLVMContext* llvm_context) { + llvm::SMDiagnostic diagnostic_err; + std::unique_ptr module( + llvm::parseIRFile(llvm::StringRef(filename.data(), filename.size()), + diagnostic_err, *llvm_context)); + + if (module == nullptr) { + DieWithSMDiagnosticError(&diagnostic_err); + } + + return module; +} + +string ReplaceFilenameExtension(tensorflow::StringPiece filename, + tensorflow::StringPiece new_extension) { + auto pos = filename.rfind('.'); + tensorflow::StringPiece stem = + pos == tensorflow::StringPiece::npos + ? filename + : tensorflow::StringPiece(filename.data(), pos); + return tensorflow::strings::StrCat(stem, ".", new_extension); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h new file mode 100644 index 0000000000..a6daeca95a --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h @@ -0,0 +1,50 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LLVM_GPU_BACKEND_UTILS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LLVM_GPU_BACKEND_UTILS_H_ + +#include +#include +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/stringpiece.h" + +namespace llvm { +class LLVMContext; +class Module; +} + +namespace xla { +namespace gpu { + +// Convenience function for loading a LLVM module from an IR file. The module +// is created in the given LLVM context. +// +// If loading fails for some reason, dies printing a diagnostic error. +std::unique_ptr LoadIRModule(const string& filename, + llvm::LLVMContext* llvm_context); + +// Convenience function for replacing the extension of the given filename. +// If the filename has no extension, the new extension is appended to its name. +// +// For example: +// ReplaceFilenameExtension("/foo/baz.txt", "cc") --> "/foo/baz.cc" +string ReplaceFilenameExtension(tensorflow::StringPiece filename, + tensorflow::StringPiece new_extension); + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LLVM_GPU_BACKEND_UTILS_H_ diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils_test.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils_test.cc new file mode 100644 index 0000000000..3848e58b0d --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils_test.cc @@ -0,0 +1,55 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h" + +#include + +#include "tensorflow/core/lib/io/path.h" + +#include "external/llvm/include/llvm/IR/LLVMContext.h" +#include "external/llvm/include/llvm/IR/Module.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace gpu { +namespace { + +const char kSaxpyIRFile[] = + "compiler/xla/service/gpu/llvm_gpu_backend/tests_data/saxpy.ll"; + +TEST(UtilsTest, TestLoadIRModule) { + llvm::LLVMContext llvm_context; + string test_srcdir = tensorflow::testing::TensorFlowSrcRoot(); + std::unique_ptr module = LoadIRModule( + tensorflow::io::JoinPath(test_srcdir, kSaxpyIRFile), &llvm_context); + // Sanity check that the module was loaded properly. 
+ ASSERT_NE(nullptr, module); + ASSERT_NE(std::string::npos, module->getModuleIdentifier().find("saxpy.ll")); + ASSERT_NE(nullptr, module->getFunction("cuda_saxpy")); +} + +TEST(UtilsTest, TestReplaceFilenameExtension) { + ASSERT_EQ(ReplaceFilenameExtension("baz.tx", "cc"), "baz.cc"); + ASSERT_EQ(ReplaceFilenameExtension("/foo/baz.txt", "cc"), "/foo/baz.cc"); + ASSERT_EQ(ReplaceFilenameExtension("/foo/baz.", "-nvptx.dummy"), + "/foo/baz.-nvptx.dummy"); + ASSERT_EQ(ReplaceFilenameExtension("/foo/baz", "cc"), "/foo/baz.cc"); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc new file mode 100644 index 0000000000..ca70d55fab --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc @@ -0,0 +1,408 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/pad_insertion.h" + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/window_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace gpu { + +namespace { +bool IsForwardConvolutionCanonical(const HloInstruction& conv) { + CHECK_EQ(HloOpcode::kConvolution, conv.opcode()); + return window_util::HasEvenPadding(conv.window()) && + !window_util::HasNegativePadding(conv.window()) && + !window_util::HasDilation(conv.window()); +} + +// If the (positive and negative) padding on the input operand of a convolution +// can't be folded into a cuDNN convolution libcall (e.g. uneven padding and +// dilation), returns kPad and/or kSlice instructions that explicitly apply the +// padding; otherwise returns the original input operand. When there is both +// positive padding (including dilation) and negative padding, we insert both +// kPad and kSlice. +HloInstruction* MaybePaddedAndSlicedInput( + const Window& conv_window, const ConvolutionDimensionNumbers& conv_dnums, + HloInstruction* input) { + HloComputation* computation = input->parent(); + if (!window_util::HasEvenPadding(conv_window) || + window_util::HasBaseDilation(conv_window)) { + // If padding is uneven or has dilation, we insert a kPad instruction that + // applies positive padding and dilation. 
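As a worked example of that mapping (numbers chosen for illustration): a spatial window dimension with padding_low=2, padding_high=1 and base_dilation=2 becomes a kPad dimension with edge_padding_low=2, edge_padding_high=1 and interior_padding=1 (base_dilation - 1); any negative padding is left for the kSlice inserted further below.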
+ PaddingConfig padding_config = + MakeNoPaddingConfig(input->shape().dimensions_size()); + for (size_t i = 0; i < conv_dnums.spatial_dimensions().size(); ++i) { + int64 dim = conv_dnums.spatial_dimensions(i); + padding_config.mutable_dimensions(dim)->set_edge_padding_low( + std::max(0LL, conv_window.dimensions(i).padding_low())); + padding_config.mutable_dimensions(dim)->set_edge_padding_high( + std::max(0LL, conv_window.dimensions(i).padding_high())); + padding_config.mutable_dimensions(dim)->set_interior_padding( + conv_window.dimensions(i).base_dilation() - 1); + } + PrimitiveType element_type = input->shape().element_type(); + HloInstruction* padding = + computation->AddInstruction(HloInstruction::CreateConstant( + MakeUnique(LiteralUtil::Zero(element_type)))); + input = computation->AddInstruction(HloInstruction::CreatePad( + ShapeInference::InferPadShape( + /*operand_shape=*/input->shape(), + /*padding_value_shape=*/ShapeUtil::MakeShape(element_type, {}), + padding_config) + .ConsumeValueOrDie(), + input, padding, padding_config)); + } + + if (window_util::HasNegativePadding(conv_window)) { + // If the window has negative padding, insert a kSlice that explicitly + // applies negative padding. + // + // For each dimension, initialize the start index to 0 and the limit index + // to the size of that dimension. + std::vector start_indices(input->shape().dimensions_size(), 0); + std::vector limit_indices(input->shape().dimensions().begin(), + input->shape().dimensions().end()); + for (size_t i = 0; i < conv_dnums.spatial_dimensions().size(); ++i) { + int64 dim = conv_dnums.spatial_dimensions(i); + // If dimension "dim" has negative padding, increase the start index or + // decrement the limit index by the amount of negative padding. + start_indices[dim] += + std::max(0LL, -conv_window.dimensions(i).padding_low()); + limit_indices[dim] -= + std::max(0LL, -conv_window.dimensions(i).padding_high()); + } + + input = computation->AddInstruction(HloInstruction::CreateSlice( + ShapeInference::InferSliceShape(input->shape(), start_indices, + limit_indices) + .ConsumeValueOrDie(), + input, start_indices, limit_indices)); + } + + return input; +} + +// If the padding on the kernel operand of a convolution can't be folded into a +// cuDNN convolution libcall (e.g. dilation), returns a kPad instruction that +// explicitly applies the padding; otherwise returns the original kernel +// operand. +HloInstruction* MaybePaddedKernel(const Window& conv_window, + const ConvolutionDimensionNumbers& conv_dnums, + HloInstruction* kernel) { + if (!window_util::HasWindowDilation(conv_window)) { + return kernel; + } + + // Compute the shape and padding config of the pad to be inserted. 
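For a concrete illustration (numbers chosen for illustration): a window dimension with window_dilation=3 maps to interior_padding=2 on the corresponding kernel spatial dimension, i.e. two zeros are inserted between adjacent kernel elements while the edge padding stays zero.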
+ PaddingConfig padding_config; + for (size_t i = 0; i < kernel->shape().dimensions_size(); ++i) { + padding_config.add_dimensions(); + } + for (size_t i = 0; i < conv_dnums.spatial_dimensions().size(); ++i) { + int64 dim = conv_dnums.spatial_dimensions(i); + padding_config.mutable_dimensions(dim)->set_interior_padding( + conv_window.dimensions(i).window_dilation() - 1); + } + + HloComputation* computation = kernel->parent(); + PrimitiveType element_type = kernel->shape().element_type(); + HloInstruction* padding = + computation->AddInstruction(HloInstruction::CreateConstant( + MakeUnique(LiteralUtil::Zero(element_type)))); + return computation->AddInstruction(HloInstruction::CreatePad( + ShapeInference::InferPadShape( + /*operand_shape=*/kernel->shape(), + /*padding_value_shape=*/ShapeUtil::MakeShape(element_type, {}), + padding_config) + .ConsumeValueOrDie(), + kernel, padding, padding_config)); +} +} // namespace + +bool PadInsertion::CanonicalizeForwardConvolution(HloInstruction* conv) { + if (IsForwardConvolutionCanonical(*conv)) { + return false; + } + + // Insert slices and/or pads between the convolution and its input and/or + // kernel operand. + HloInstruction* new_input = MaybePaddedAndSlicedInput( + conv->window(), conv->convolution_dimension_numbers(), + conv->mutable_operand(0)); + HloInstruction* new_kernel = + MaybePaddedKernel(conv->window(), conv->convolution_dimension_numbers(), + conv->mutable_operand(1)); + + // Remove the padding from convolution's window field. These paddings are + // made explicit with the inserted pads. + Window new_conv_window = conv->window(); + for (size_t i = 0; i < new_conv_window.dimensions_size(); ++i) { + WindowDimension* dim = new_conv_window.mutable_dimensions(i); + dim->set_padding_low(0); + dim->set_padding_high(0); + dim->set_base_dilation(1); + dim->set_window_dilation(1); + } + conv->parent()->ReplaceWithNewInstruction( + conv, HloInstruction::CreateConvolve( + conv->shape(), new_input, new_kernel, new_conv_window, + conv->convolution_dimension_numbers())); + return true; +} + +namespace { +void IncreasePaddingLowBy(int64 delta, WindowDimension* window_dim) { + window_dim->set_padding_low(window_dim->padding_low() + delta); +} + +void IncreasePaddingHighBy(int64 delta, WindowDimension* window_dim) { + window_dim->set_padding_high(window_dim->padding_high() + delta); +} +} // namespace + +bool PadInsertion::CanonicalizeBackwardFilterConvolution( + HloInstruction* backward_conv) { + if (window_util::HasEvenPadding(backward_conv->window())) { + return false; + } + + // A backward filter convolution with uneven padding can be canonicalized to + // one with even padding by padding the activations (input) beforehand. For + // example, + // BackwardFilterConv(ABCD, xyz, padding_low=1, padding_high=2) + // is equivalent to + // ABCD0 = Pad(ABCD, padding_high=1) + // BackwardFilterConv(ABCD0, xyz, padding_low=pading_high=1) + // We choose the lesser of padding_low and padding_high as the new padding. + HloInstruction* transpose = backward_conv->fused_expression_root(); + HloInstruction* forward_conv = transpose->mutable_operand(0); + HloInstruction* input = backward_conv->mutable_operand(0); + Window new_forward_conv_window = forward_conv->window(); + Window new_backward_conv_window = backward_conv->window(); + // input_padding_config is the config of the kPad to be inserted. 
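Working through the ABCD example above with padding_low=1 and padding_high=2: new_conv_padding = min(1, 2) = 1, so the inserted kPad gets edge_padding_low = 1 - 1 = 0 and edge_padding_high = 2 - 1 = 1, and both the backward convolution and its inner forward convolution are left with even padding (1, 1).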
+ PaddingConfig input_padding_config = + MakeNoPaddingConfig(ShapeUtil::Rank(input->shape())); + ConvolutionDimensionNumbers forward_conv_dnums = + forward_conv->convolution_dimension_numbers(); + ConvolutionDimensionNumbers backward_conv_dnums = + backward_conv->convolution_dimension_numbers(); + for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) { + int64 padding_low = backward_conv->window().dimensions(i).padding_low(); + int64 padding_high = backward_conv->window().dimensions(i).padding_high(); + if (padding_low < 0 || padding_high < 0) { + // TODO(b/32744257): The following canonicalization wouldn't remove + // negative padding in a backward convolution, and would therefore cause + // cuDNN convolution (which doesn't support negative padding) to fail. + return false; + } + // If the backward convolution has uneven padding on the activations, we + // move some padding on the larger end to "internal" padding, so that the + // backward convolution produces larger weight gradients which get sliced + // later. Therefore, the amount of new padding (low or high) is the minimum + // of the amount of old padding low and old padding high. + int64 new_conv_padding = std::min(padding_low, padding_high); + int64 dim = backward_conv_dnums.spatial_dimensions(i); + input_padding_config.mutable_dimensions(dim)->set_edge_padding_low( + padding_low - new_conv_padding); + input_padding_config.mutable_dimensions(dim)->set_edge_padding_high( + padding_high - new_conv_padding); + + // Since we move some padding from the backward convolution to the kPad, we + // need to accordingly reduce the padding amount of the backward convolution + // and its inner forward convolution. + IncreasePaddingLowBy(-(padding_low - new_conv_padding), + new_backward_conv_window.mutable_dimensions(i)); + IncreasePaddingHighBy(-(padding_high - new_conv_padding), + new_backward_conv_window.mutable_dimensions(i)); + IncreasePaddingLowBy(-(padding_low - new_conv_padding), + new_forward_conv_window.mutable_dimensions(i)); + IncreasePaddingHighBy(-(padding_high - new_conv_padding), + new_forward_conv_window.mutable_dimensions(i)); + } + + // Create a new backward convolution replacing the old one. + HloComputation* computation = backward_conv->parent(); + HloInstruction* output = backward_conv->mutable_operand(1); + HloInstruction* padding = computation->AddInstruction( + HloInstruction::CreateConstant(MakeUnique( + LiteralUtil::Zero(input->shape().element_type())))); + HloInstruction* padded_input = + computation->AddInstruction(HloInstruction::CreatePad( + ShapeInference::InferPadShape(input->shape(), padding->shape(), + input_padding_config) + .ConsumeValueOrDie(), + input, padding, input_padding_config)); + + HloInstruction* new_forward_conv = + computation->AddInstruction(HloInstruction::CreateConvolve( + ShapeInference::InferConvolveShape( + padded_input->shape(), output->shape(), new_forward_conv_window, + forward_conv_dnums) + .ConsumeValueOrDie(), + padded_input, output, new_forward_conv_window, forward_conv_dnums)); + + HloInstruction* new_transpose = + computation->AddInstruction(HloInstruction::CreateTranspose( + ShapeInference::InferTransposeShape(new_forward_conv->shape(), + transpose->dimensions()) + .ConsumeValueOrDie(), + new_forward_conv, transpose->dimensions())); + + // Fuse the new forward convolution and the new transpose to the new backward + // convolution. 
+ HloInstruction* new_backward_conv = + computation->CreateFusionInstructionForBackwardConvolution( + {new_transpose, new_forward_conv}, + HloInstruction::FusionKind::kConvBackwardFilter, + new_backward_conv_window, backward_conv_dnums); + computation->ReplaceInstruction(backward_conv, new_backward_conv); + return true; +} + +bool PadInsertion::CanonicalizeBackwardInputConvolution( + HloInstruction* backward_conv) { + if (window_util::HasEvenPadding(backward_conv->window())) { + return false; + } + + HloInstruction* forward_conv = backward_conv->fused_expression_root(); + HloInstruction* reverse_filter = forward_conv->mutable_operand(1); + Window new_forward_conv_window = forward_conv->window(); + Window new_backward_conv_window = backward_conv->window(); + ConvolutionDimensionNumbers forward_conv_dnums = + forward_conv->convolution_dimension_numbers(); + ConvolutionDimensionNumbers backward_conv_dnums = + backward_conv->convolution_dimension_numbers(); + for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) { + int64 padding_low = backward_conv->window().dimensions(i).padding_low(); + int64 padding_high = backward_conv->window().dimensions(i).padding_high(); + if (padding_low < 0 || padding_high < 0) { + // TODO(b/32744257): The following canonicalization wouldn't remove + // negative padding in a backward convolution, and would therefore cause + // cuDNN convolution (which doesn't support negative padding) to fail. + return false; + } + // If the backward convolution has uneven padding on the activations, we + // move some padding on the larger end to "internal" padding, so that the + // backward convolution produces larger activations which get sliced later. + // + // For example, suppose we have a non-canonical HLO + // [A] = BackwardInputConvolve([a b], [x y z], padding=(low=2,high=1)) + // where the amount of padding low is larger, we can canonicalize it to + // [B A] = BackwardInputConvolve([a b], [x y z], padding=(low=1,high=1)) + // [A] = Slice([B A]) + // For consistency, we need to increase the low padding of the inner + // convolution by 1 as well because the input is larger now. + if (padding_low > padding_high) { + IncreasePaddingLowBy(padding_high - padding_low, + new_backward_conv_window.mutable_dimensions(i)); + IncreasePaddingLowBy(padding_low - padding_high, + new_forward_conv_window.mutable_dimensions(i)); + } else if (padding_low < padding_high) { + IncreasePaddingHighBy(padding_low - padding_high, + new_backward_conv_window.mutable_dimensions(i)); + IncreasePaddingHighBy(padding_high - padding_low, + new_forward_conv_window.mutable_dimensions(i)); + } + } + + // Create a new backward convolution replacing the old one. 
+ HloComputation* computation = backward_conv->parent(); + HloInstruction* output = backward_conv->mutable_operand(0); + HloInstruction* filter = backward_conv->mutable_operand(1); + HloInstruction* new_reverse_filter = + computation->AddInstruction(HloInstruction::CreateReverse( + filter->shape(), filter, reverse_filter->dimensions())); + HloInstruction* new_forward_conv = + computation->AddInstruction(HloInstruction::CreateConvolve( + ShapeInference::InferConvolveShape( + output->shape(), new_reverse_filter->shape(), + new_forward_conv_window, forward_conv_dnums) + .ConsumeValueOrDie(), + output, new_reverse_filter, new_forward_conv_window, + forward_conv_dnums)); + HloInstruction* new_backward_conv = + computation->CreateFusionInstructionForBackwardConvolution( + {new_forward_conv, new_reverse_filter}, + HloInstruction::FusionKind::kConvBackwardInput, + new_backward_conv_window, backward_conv_dnums); + + // Slice the new backward convolution. + // + // Initialize start_indices and limit_indices as no slicing. + std::vector start_indices(new_backward_conv->shape().dimensions_size(), + 0LL); + std::vector limit_indices( + new_backward_conv->shape().dimensions().begin(), + new_backward_conv->shape().dimensions().end()); + for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) { + int64 padding_low = backward_conv->window().dimensions(i).padding_low(); + int64 padding_high = backward_conv->window().dimensions(i).padding_high(); + int64 dim = backward_conv_dnums.spatial_dimensions(i); + if (padding_low > padding_high) { + // If the amount of low padding (of the old backward convolution) is + // larger, we internally pad the low end of the activations and slice + // internal padding out here. + start_indices[dim] += padding_low - padding_high; + } else if (padding_low < padding_high) { + // If the amount of high padding is larger, we slice out the internal + // padding on the high end. + limit_indices[dim] -= padding_high - padding_low; + } + } + + // Replace the old backward convolution with the slice. + CHECK(ShapeUtil::Compatible( + ShapeInference::InferSliceShape(new_backward_conv->shape(), start_indices, + limit_indices) + .ConsumeValueOrDie(), + backward_conv->shape())); + computation->ReplaceWithNewInstruction( + backward_conv, + HloInstruction::CreateSlice(backward_conv->shape(), new_backward_conv, + start_indices, limit_indices)); + return true; +} + +StatusOr PadInsertion::Run(HloModule* module) { + bool changed = false; + for (HloInstruction* instruction : + module->entry_computation()->MakeInstructionPostOrder()) { + if (instruction->opcode() == HloOpcode::kConvolution) { + changed |= CanonicalizeForwardConvolution(instruction); + } else if (instruction->opcode() == HloOpcode::kFusion) { + switch (instruction->fusion_kind()) { + case HloInstruction::FusionKind::kConvBackwardFilter: + changed |= CanonicalizeBackwardFilterConvolution(instruction); + break; + case HloInstruction::FusionKind::kConvBackwardInput: + changed |= CanonicalizeBackwardInputConvolution(instruction); + break; + default: + break; + } + } + } + return changed; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.h b/tensorflow/compiler/xla/service/gpu/pad_insertion.h new file mode 100644 index 0000000000..ec517ed8e6 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.h @@ -0,0 +1,43 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_INSERTION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_INSERTION_H_ + +#include "tensorflow/compiler/xla/service/hlo_pass.h" + +namespace xla { +namespace gpu { + +// An HLO pass that canonicalizes convolution instructions for GPU codegen. It +// inserts Pad instructions before Convolution instructions with uncanonicalized +// padding, so that they can be lowered to cuDNN convolution. +class PadInsertion : public HloPass { + public: + PadInsertion() : HloPass("pad insertion") {} + + StatusOr Run(HloModule* module) override; + + private: + // Returns if any changes are made to the parent computation. + bool CanonicalizeForwardConvolution(HloInstruction* conv); + bool CanonicalizeBackwardFilterConvolution(HloInstruction* backward_conv); + bool CanonicalizeBackwardInputConvolution(HloInstruction* backward_conv); +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_INSERTION_H_ diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc new file mode 100644 index 0000000000..65610b0995 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc @@ -0,0 +1,98 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h" + +#include + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" +// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" +#include "external/llvm/include/llvm/IR/Intrinsics.h" +#include "external/llvm/include/llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace gpu { + +ParallelLoopEmitter::ParallelLoopEmitter( + BodyEmitter body_emitter, const Shape& shape, + const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder) + : LoopEmitter(body_emitter, shape, ir_builder), + launch_dimensions_(launch_dimensions) {} + +ParallelLoopEmitter::ParallelLoopEmitter( + const llvm_ir::ElementGenerator& target_element_generator, + const llvm_ir::IrArray& target_array, + const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder) + : LoopEmitter(target_element_generator, target_array, ir_builder), + launch_dimensions_(launch_dimensions) {} + +llvm_ir::IrArray::Index ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock() { + // Emit the following code in LLVM IR: + // linear_index = blockIdx.x * blockDim.x + threadIdx.x; + // if (linear_index < num_elements) { + // array_index = LinearIndexToMultidimensionalIndex(shape_, linear_index); + // ... + // } + + // Per the PTX documentation: + // "It is guaranteed that [...] 0 <= %ctaid.x < %nctaid.x" + // + // %nctaid.x is currently specified as 2147483647. + llvm::Value* block_id = llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, ir_builder_); + llvm_ir::AddRangeMetadata(0, launch_dimensions_.block_count(), + static_cast(block_id)); + block_id = + ir_builder_->CreateZExt(block_id, ir_builder_->getInt64Ty(), "block_id"); + + // Per the PTX documentation: + // "It is guaranteed that [...] 0 <= %tid.x < %ntid.x" + // + // %ntid.x is currently specified as 1024. + llvm::Value* thread_id = llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, ir_builder_); + llvm_ir::AddRangeMetadata(0, launch_dimensions_.threads_per_block(), + static_cast(thread_id)); + thread_id = ir_builder_->CreateZExt(thread_id, ir_builder_->getInt64Ty(), + "thread_id"); + + llvm::Value* linear_index = ir_builder_->CreateAdd( + ir_builder_->CreateMul( + block_id, + ir_builder_->getInt64(launch_dimensions_.threads_per_block()), "", + /*HasNUW=*/true, /*HasNSW=*/true), + thread_id, "linear_index", /*HasNUW=*/true, /*HasNSW=*/true); + + auto if_in_bounds = llvm_ir::EmitIfThenElse( + ir_builder_->CreateICmpULT( + linear_index, ir_builder_->getInt64(ShapeUtil::ElementsIn(shape_))), + "in_bounds", ir_builder_, false); + + // Set exit_bb_ to the exit block of the if structure. + exit_bb_ = if_in_bounds.after_block; + CHECK_NE(nullptr, exit_bb_); + + // Set IR builder insertion point to the body of the if structure. 
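This guard is what makes over-provisioned launches safe. For example, with 1000 elements and 256 threads per block (illustrative numbers), the block count works out to ceil(1000/256) = 4, so 1024 (block, thread) pairs are launched and the last 24 simply fail the bounds check. A CPU-side sketch of the same index arithmetic, with hypothetical names, is:

#include <cassert>
#include <cstdint>
#include <vector>

// Models the emitted "linear_index = block_id * threads_per_block + thread_id"
// computation and its bounds check, verifying that every element is visited
// exactly once when block_count * threads_per_block >= num_elements.
void CheckParallelLoopCoverage(int64_t num_elements, int64_t block_count,
                               int64_t threads_per_block) {
  std::vector<int> visits(num_elements, 0);
  for (int64_t block_id = 0; block_id < block_count; ++block_id) {
    for (int64_t thread_id = 0; thread_id < threads_per_block; ++thread_id) {
      const int64_t linear_index = block_id * threads_per_block + thread_id;
      if (linear_index < num_elements) {  // the "in_bounds" branch above
        ++visits[linear_index];
      }
    }
  }
  for (int count : visits) {
    assert(count == 1);  // Each element is handled by exactly one thread.
  }
}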
+  llvm_ir::SetToFirstInsertPoint(if_in_bounds.true_block, ir_builder_);
+  return llvm_ir::IrArray::Index(linear_index, shape_, ir_builder_);
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h
new file mode 100644
index 0000000000..73ca28cd84
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h
@@ -0,0 +1,58 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PARALLEL_LOOP_EMITTER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PARALLEL_LOOP_EMITTER_H_
+
+#include "external/llvm/include/llvm/IR/IRBuilder.h"
+#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
+
+namespace xla {
+namespace gpu {
+
+// Emits a parallel loop for every element in the given array shape. The
+// emitted loop is executed by multiple threads in parallel: each thread
+// instance of the loop iterates over part of the array, and the threads
+// collectively iterate over the entire array.
+class ParallelLoopEmitter : public llvm_ir::LoopEmitter {
+ public:
+  // `launch_dimensions` specifies the block and thread counts used to
+  // parallelize the loop. The meanings of the other parameters are the same
+  // as for LoopEmitter.
+  ParallelLoopEmitter(BodyEmitter body_emitter, const Shape& shape,
+                      const LaunchDimensions& launch_dimensions,
+                      llvm::IRBuilder<>* ir_builder);
+  // Constructs a ParallelLoopEmitter from an element generator that generates
+  // each element of the given target array.
+  ParallelLoopEmitter(const llvm_ir::ElementGenerator& target_element_generator,
+                      const llvm_ir::IrArray& target_array,
+                      const LaunchDimensions& launch_dimensions,
+                      llvm::IRBuilder<>* ir_builder);
+  ParallelLoopEmitter(const ParallelLoopEmitter&) = delete;
+  ParallelLoopEmitter& operator=(const ParallelLoopEmitter&) = delete;
+  ~ParallelLoopEmitter() override = default;
+
+  llvm_ir::IrArray::Index EmitIndexAndSetExitBasicBlock() override;
+
+ private:
+  // The thread and block dimensions used to parallelize the loop.
+  const LaunchDimensions launch_dimensions_;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PARALLEL_LOOP_EMITTER_H_
diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc
new file mode 100644
index 0000000000..d0d2deee24
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc
@@ -0,0 +1,99 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" + +#include +#include + +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/bits.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" + +namespace se = ::perftools::gputools; + +namespace xla { +namespace gpu { + +std::ostream& operator<<(std::ostream& out, + const LaunchDimensions& launch_dims) { + out << tensorflow::strings::Printf("[block: %lld, thread: %lld]", + launch_dims.block_count(), + launch_dims.threads_per_block()); + return out; +} + +// Calculates the launch dimensions used to invoke `hlo`. +LaunchDimensions CalculateLaunchDimensions( + const Shape& shape, const se::DeviceDescription& device_desc, + PartitionStrategy partition_strategy) { + int64 warp_size = device_desc.threads_per_warp(); + + int64 num_elements = ShapeUtil::ElementsIn(shape); + if (num_elements <= 1) { + return LaunchDimensions(); + } + + // Calculate the number of threads per block. + // Initialize threads_per_block as the threads-per-block limit. + int64 threads_per_block = device_desc.threads_per_block_limit(); + VLOG(2) << "Initial # of threads per block = " << threads_per_block; + + if (partition_strategy == PartitionStrategy::kLatency) { + // Limit the thread count to allow maximum number of registers per thread. + // TODO(b/28560520): We don't have to assume the emitted kernel will use up + // all the registers. We could use ptxas to examine the actual number of + // register used, and set the thread count accordingly. + int64 threads_per_block_limit_due_to_registers = + device_desc.registers_per_core_limit() / + device_desc.registers_per_thread_limit(); + CHECK_NE(0, threads_per_block_limit_due_to_registers); + if (threads_per_block_limit_due_to_registers < threads_per_block) { + threads_per_block = + // Make `threads_per_block` a multiple of warp size to use GPU + // efficiently. + warp_size * + std::max(1LL, threads_per_block_limit_due_to_registers / warp_size); + VLOG(2) << "Update # of threads per block due to register pressure = " + << threads_per_block; + } + } + + if (num_elements < threads_per_block) { + threads_per_block = num_elements; + VLOG(2) << "Update # of threads per block to the element count (" + << threads_per_block << ") because the latter is smaller."; + } + + // Calculate the block count. 
+  // We copy the strategy used by Eigen:
+  // eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h
+  int64 block_count = CeilOfRatio(num_elements, threads_per_block);
+  VLOG(2) << tensorflow::strings::Printf(
+      "Initialized the block count to ceil(# of elements / threads per "
+      "block) = ceil(%lld/%lld) = %lld",
+      num_elements, threads_per_block, block_count);
+
+  return LaunchDimensions(block_count, threads_per_block);
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.h b/tensorflow/compiler/xla/service/gpu/partition_assignment.h
new file mode 100644
index 0000000000..8ac4c59966
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.h
@@ -0,0 +1,75 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Algorithms and data structures for partition assignment.
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PARTITION_ASSIGNMENT_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PARTITION_ASSIGNMENT_H_
+
+#include
+#include
+#include
+
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace gpu {
+
+enum class PartitionStrategy {
+  // Optimized for latency by allowing maximum number of registers per thread.
+  kLatency,
+  // Optimized for throughput. This may limit the number of registers per
+  // thread and cause longer latency.
+  kThroughput
+};
+
+// Encapsulates the launch dimensions of a kernel, e.g., the block count and
+// the number of threads per block.
+class LaunchDimensions {
+ public:
+  // The default constructor creates a launch dimension that indicates
+  // single-threaded execution.
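The heuristic implemented by CalculateLaunchDimensions above can be sketched standalone with plain integers (the device numbers below are hypothetical, not read from a real DeviceDescription): cap the thread count by the register budget, keep it a multiple of the warp size and no larger than the element count, then derive the block count with a ceiling division.

#include <algorithm>
#include <cstdint>
#include <iostream>

struct LaunchDims { int64_t block_count; int64_t threads_per_block; };

LaunchDims CalculateLaunchDims(int64_t num_elements, int64_t warp_size,
                               int64_t threads_per_block_limit,
                               int64_t registers_per_core,
                               int64_t registers_per_thread) {
  if (num_elements <= 1) return {1, 1};

  int64_t threads_per_block = threads_per_block_limit;
  // Assume each thread may use the per-thread register limit; shrink the block
  // so one block's registers fit on a core, rounded down to a warp multiple.
  int64_t limit_due_to_registers = registers_per_core / registers_per_thread;
  if (limit_due_to_registers < threads_per_block) {
    threads_per_block =
        warp_size * std::max<int64_t>(1, limit_due_to_registers / warp_size);
  }
  threads_per_block = std::min(threads_per_block, num_elements);

  // block_count = ceil(num_elements / threads_per_block).
  int64_t block_count =
      (num_elements + threads_per_block - 1) / threads_per_block;
  return {block_count, threads_per_block};
}

int main() {
  // Hypothetical device: 32-wide warps, 1024-thread blocks, 64K registers per
  // core, 255 registers per thread.
  LaunchDims dims = CalculateLaunchDims(1 << 20, 32, 1024, 65536, 255);
  std::cout << "blocks=" << dims.block_count
            << " threads_per_block=" << dims.threads_per_block << "\n";
  return 0;
}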
+ LaunchDimensions() : block_count_(1), threads_per_block_(1) {} + + LaunchDimensions(int64 block_count, int64 threads_per_block) + : block_count_(block_count), threads_per_block_(threads_per_block) {} + + bool IsSinglethreaded() const { + return block_count_ == 1 && threads_per_block_ == 1; + } + + int64 block_count() const { return block_count_; } + int64 threads_per_block() const { return threads_per_block_; } + + private: + int64 block_count_; + int64 threads_per_block_; +}; + +std::ostream& operator<<(std::ostream& out, + const LaunchDimensions& launch_dims); + +LaunchDimensions CalculateLaunchDimensions( + const Shape& shape, + const perftools::gputools::DeviceDescription& device_desc, + PartitionStrategy partition_strategy = PartitionStrategy::kLatency); + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PARTITION_ASSIGNMENT_H_ diff --git a/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc b/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc new file mode 100644 index 0000000000..d8a43091d4 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc @@ -0,0 +1,45 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h" + +#include "tensorflow/core/lib/core/errors.h" + +namespace xla { +namespace gpu { + +SequentialThunk::SequentialThunk(std::vector>&& thunks, + const HloInstruction* hlo) + : Thunk(Kind::kSequential, hlo), thunks_(std::move(thunks)) {} + +tensorflow::Status SequentialThunk::Initialize( + const GpuExecutable& executable) { + for (auto& thunk : thunks_) { + TF_RETURN_IF_ERROR(thunk->Initialize(executable)); + } + return tensorflow::Status::OK(); +} + +tensorflow::Status SequentialThunk::ExecuteOnStream( + const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) { + for (const auto& thunk : thunks_) { + TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(buffer_allocations, stream)); + } + return tensorflow::Status::OK(); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/sequential_thunk.h b/tensorflow/compiler/xla/service/gpu/sequential_thunk.h new file mode 100644 index 0000000000..32c5b748ab --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/sequential_thunk.h @@ -0,0 +1,54 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_SEQUENTIAL_THUNK_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_SEQUENTIAL_THUNK_H_ + +#include + +#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/thunk.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { + +// A thunk that wraps a list of sub-thunks. Executing this thunk executes all +// the sub-thunks sequentially. This is useful to implement instructions that +// require multiple kernel launches or library calls. +class SequentialThunk : public Thunk { + public: + SequentialThunk(std::vector>&& thunks, + const HloInstruction* hlo); + SequentialThunk(const SequentialThunk&) = delete; + SequentialThunk& operator=(const SequentialThunk&) = delete; + + const std::vector>& thunks() const { return thunks_; } + + tensorflow::Status Initialize(const GpuExecutable& executable) override; + tensorflow::Status ExecuteOnStream( + const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) override; + + private: + // The list of sub-thunks. + std::vector> thunks_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_SEQUENTIAL_THUNK_H_ diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment.cc new file mode 100644 index 0000000000..5065e7aedd --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment.cc @@ -0,0 +1,135 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" + +#include "tensorflow/compiler/xla/legacy_flags/stream_assignment_flags.h" +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" + +namespace xla { +namespace gpu { + +bool StreamAssignment::HasStreamAssigned(const HloInstruction& hlo) const { + return hlo_to_stream_number_.count(&hlo); +} + +int StreamAssignment::StreamNumberForHlo(const HloInstruction& hlo) const { + return FindOrDie(hlo_to_stream_number_, &hlo); +} + +void StreamAssignment::AssignStreamToHlo(const HloInstruction* hlo, + int stream_no) { + CHECK_GE(stream_no, 0); + if (stream_no >= stream_count_) { + stream_count_ = stream_no + 1; + } + InsertOrDie(&hlo_to_stream_number_, hlo, stream_no); + VLOG(2) << "Assign stream #" << stream_no << " to " << hlo->ToString(); +} + +namespace { + +// Returns whether the two HLOs can run concurrently, i.e., neither is a +// transitive consumer of the other. +bool CanRunConcurrently( + const HloInstruction& a, const HloInstruction& b, + const HloComputation::ReachabilityMap& transitive_operands) { + return !transitive_operands.IsConnected(&a, &b); +} + +// Returns which existing stream to assign to `hlo`, or -1 if a stream is not +// needed. `stream_assignment` is the existing stream assignment for all +// instructions topologically before `hlo`. `seen_gemms` contains all GEMMs that +// are topologically before `hlo`. +int ComputeStreamToAssign( + const HloInstruction& hlo, const StreamAssignment& stream_assignment, + const HloComputation::ReachabilityMap& transitive_operands, + const std::vector& seen_gemms) { + if (hlo.opcode() == HloOpcode::kParameter || + hlo.opcode() == HloOpcode::kConstant) { + // kParameter and kConstant do not need a thunk. + return -1; + } + + legacy_flags::StreamAssignmentFlags* flags = + legacy_flags::GetStreamAssignmentFlags(); + if (flags->xla_gpu_disable_multi_streaming) { + return 0; + } + + if (!ImplementedAsGemm(hlo)) { + // If `hlo` is not implemented as a GEMM, keep it close to its operands to + // avoid excessive synchronization. + int stream_no = -1; + for (const auto* operand : hlo.operands()) { + if (stream_assignment.HasStreamAssigned(*operand)) { + stream_no = + std::max(stream_no, stream_assignment.StreamNumberForHlo(*operand)); + } + } + if (stream_no == -1) { + stream_no = 0; + } + return stream_no; + } + + // Assign different streams to concurrent GEMMs. The code below uses a + // greedy approach. First, we compute as forbidden_stream_numbers the + // streams assigned to GEMMs that are concurrent with `hlo`. Then, we assign + // `hlo` a different stream. 
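The greedy choice described in the comment above, shown in isolation: collect the streams of concurrent GEMMs into a forbidden set and take the lowest-numbered stream outside it, opening a new stream only if every existing one conflicts. This sketch uses plain ints in place of HLO instructions and takes a precomputed list of conflicting streams; it is illustrative only.

#include <iostream>
#include <set>
#include <vector>

// Picks a stream for a new GEMM given the streams of the GEMMs it can run
// concurrently with. `stream_count` is the number of streams used so far.
int PickStream(const std::vector<int>& concurrent_gemm_streams,
               int stream_count) {
  std::set<int> forbidden(concurrent_gemm_streams.begin(),
                          concurrent_gemm_streams.end());
  for (int stream_no = 0; stream_no < stream_count; ++stream_no) {
    if (!forbidden.count(stream_no)) {
      return stream_no;  // Reuse an existing, non-conflicting stream.
    }
  }
  return stream_count;  // All existing streams conflict; open a new one.
}

int main() {
  // Two concurrent GEMMs already occupy streams 0 and 1 out of 2 streams.
  std::cout << PickStream({0, 1}, 2) << "\n";  // prints 2 (a new stream)
  // Only stream 1 conflicts; stream 0 can be reused.
  std::cout << PickStream({1}, 2) << "\n";     // prints 0
  return 0;
}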
+ std::set forbidden_stream_numbers; + for (const auto* seen_gemm : seen_gemms) { + int stream_no = stream_assignment.StreamNumberForHlo(*seen_gemm); + if (!forbidden_stream_numbers.count(stream_no) && + CanRunConcurrently(*seen_gemm, hlo, transitive_operands)) { + forbidden_stream_numbers.insert(stream_no); + } + } + + for (int stream_no = 0; stream_no < stream_assignment.StreamCount(); + ++stream_no) { + if (!forbidden_stream_numbers.count(stream_no)) { + return stream_no; + } + } + return stream_assignment.StreamCount(); +} + +} // namespace + +std::unique_ptr AssignStreams(const HloModule& module) { + auto stream_assignment = MakeUnique(); + const HloComputation& computation = *module.entry_computation(); + std::unique_ptr transitive_operands = + computation.ComputeTransitiveOperands(); + std::vector seen_gemms; + for (const auto* hlo : computation.MakeInstructionPostOrder()) { + int stream_no = ComputeStreamToAssign(*hlo, *stream_assignment, + *transitive_operands, seen_gemms); + if (stream_no != -1) { + stream_assignment->AssignStreamToHlo(hlo, stream_no); + } + if (ImplementedAsGemm(*hlo)) { + seen_gemms.push_back(hlo); + } + } + return stream_assignment; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment.h b/tensorflow/compiler/xla/service/gpu/stream_assignment.h new file mode 100644 index 0000000000..c2df83aaa4 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment.h @@ -0,0 +1,46 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_ASSIGNMENT_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_ASSIGNMENT_H_ + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/core/lib/gtl/flatmap.h" + +namespace xla { +namespace gpu { + +// This class encapsulates the assignment of GPU streams to each HloInstruction. +class StreamAssignment { + public: + int StreamCount() const { return stream_count_; } + int StreamNumberForHlo(const HloInstruction& hlo) const; + bool HasStreamAssigned(const HloInstruction& hlo) const; + // `hlo` needs to outlive this StreamAssignment object. + void AssignStreamToHlo(const HloInstruction* hlo, int stream_no); + + private: + int stream_count_ = 1; // At least the main stream. + tensorflow::gtl::FlatMap hlo_to_stream_number_; +}; + +// Assigns GPU streams to instructions in `module`. 
+std::unique_ptr AssignStreams(const HloModule& module); + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_ASSIGNMENT_H_ diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc new file mode 100644 index 0000000000..28d47d2b0f --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc @@ -0,0 +1,132 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/strings/stringprintf.h" + +namespace xla { +namespace gpu { + +class StreamAssignmentTest : public HloTestBase { + protected: + // Pre-canned shapes. + Shape f32_2x2_ = ShapeUtil::MakeShape(F32, {2, 2}); +}; + +TEST_F(StreamAssignmentTest, SequentialMatMul) { + HloComputation::Builder builder("entry_computation"); + HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, f32_2x2_, /*name=*/"x")); + HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, f32_2x2_, /*name=*/"y")); + HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/2, f32_2x2_, /*name=*/"z")); + HloInstruction* dot1 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, x, y)); + HloInstruction* dot2 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, dot1, z)); + + HloModule module(TestName()); + module.AddEntryComputation(builder.Build(dot2)); + + std::unique_ptr assignment = AssignStreams(module); + EXPECT_EQ(assignment->StreamNumberForHlo(*dot1), + assignment->StreamNumberForHlo(*dot2)); +} + +TEST_F(StreamAssignmentTest, ConcurrentMatMul) { + HloComputation::Builder builder("entry_computation"); + HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, f32_2x2_, /*name=*/"x")); + HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, f32_2x2_, /*name=*/"y")); + HloInstruction* dot1 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, x, y)); + HloInstruction* dot2 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, y, x)); + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, dot1, dot2)); + + HloModule module(TestName()); + module.AddEntryComputation(builder.Build(add)); + 
+ std::unique_ptr assignment = AssignStreams(module); + EXPECT_NE(assignment->StreamNumberForHlo(*dot1), + assignment->StreamNumberForHlo(*dot2)); +} + +TEST_F(StreamAssignmentTest, LatticeMatMul) { + // d00 -- layer 0 + // / \ + // d10 d11 -- layer 1 + // / \ / \ + // d20 d21 d22 -- layer 2 + // \ / \ / + // d30 d31 -- layer 3 + // \ / + // d40 -- layer 4 + HloComputation::Builder builder("entry_computation"); + std::vector params; + for (int i = 0; i < 6; ++i) { + params.push_back(builder.AddInstruction(HloInstruction::CreateParameter( + i, f32_2x2_, /*name=*/tensorflow::strings::Printf("param%d", i)))); + } + HloInstruction* d00 = builder.AddInstruction(HloInstruction::CreateBinary( + f32_2x2_, HloOpcode::kDot, params[2], params[3])); + HloInstruction* d10 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, params[1], d00)); + HloInstruction* d11 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d00, params[4])); + HloInstruction* d20 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, params[0], d10)); + HloInstruction* d21 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d10, d11)); + HloInstruction* d22 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d11, params[5])); + HloInstruction* d30 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d20, d21)); + HloInstruction* d31 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d21, d22)); + HloInstruction* d40 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d30, d31)); + + HloModule module(TestName()); + module.AddEntryComputation(builder.Build(d40)); + + std::unique_ptr assignment = AssignStreams(module); + // The two dots on layer 1 are concurrent. + EXPECT_NE(assignment->StreamNumberForHlo(*d10), + assignment->StreamNumberForHlo(*d11)); + // The three dots on layer 2 are concurrent. + EXPECT_NE(assignment->StreamNumberForHlo(*d20), + assignment->StreamNumberForHlo(*d21)); + EXPECT_NE(assignment->StreamNumberForHlo(*d20), + assignment->StreamNumberForHlo(*d22)); + EXPECT_NE(assignment->StreamNumberForHlo(*d21), + assignment->StreamNumberForHlo(*d22)); + // The two dots on layer 3 are concurrent. + EXPECT_NE(assignment->StreamNumberForHlo(*d30), + assignment->StreamNumberForHlo(*d31)); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/temp_buffer_offsets.cc b/tensorflow/compiler/xla/service/gpu/temp_buffer_offsets.cc new file mode 100644 index 0000000000..3cf5dd021a --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/temp_buffer_offsets.cc @@ -0,0 +1,52 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/temp_buffer_offsets.h" + +#include "tensorflow/compiler/xla/map_util.h" + +namespace xla { +namespace gpu { + +namespace { +int64 RoundUpToAlignment(int64 value) { + // Any address of a variable residing in global memory or returned by one of + // the memory allocation routines from the driver or runtime API is always + // aligned to at least 256 bytes. + // (http://docs.nvidia.com/cuda/cuda-c-programming-guide/#device-memory-accesses) + static constexpr int64 kCudaMallocAlignment = 256; + return (value + kCudaMallocAlignment - 1) / kCudaMallocAlignment * + kCudaMallocAlignment; +} +} // namespace + +TempBufferOffsets::TempBufferOffsets( + const BufferAssignment& buffer_assignment) { + total_size_of_temp_buffers_ = 0; + for (auto i = 0; i < buffer_assignment.Allocations().size(); ++i) { + const BufferAllocation& allocation = buffer_assignment.GetAllocation(i); + if (allocation.IsPreallocatedTempBuffer()) { + InsertOrDie(&buffer_index_to_offset_, i, total_size_of_temp_buffers_); + total_size_of_temp_buffers_ += RoundUpToAlignment(allocation.size()); + } + } +} + +int64 TempBufferOffsets::GetOffset(BufferAllocation::Index index) const { + return FindOrDie(buffer_index_to_offset_, index); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/temp_buffer_offsets.h b/tensorflow/compiler/xla/service/gpu/temp_buffer_offsets.h new file mode 100644 index 0000000000..05aca99bf3 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/temp_buffer_offsets.h @@ -0,0 +1,47 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TEMP_BUFFER_OFFSETS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TEMP_BUFFER_OFFSETS_H_ + +#include + +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace gpu { + +// GpuExecutable merges all temporary buffers into one memory block. This class +// stores the offset of each temporary buffer in that memory block. +class TempBufferOffsets { + public: + explicit TempBufferOffsets(const BufferAssignment& buffer_assignment); + + int64 GetOffset(BufferAllocation::Index index) const; + int64 TotalSizeInBytes() const { return total_size_of_temp_buffers_; } + + private: + std::map buffer_index_to_offset_; + + // The total size of all temporary buffers. This includes paddings that are + // necessary for alignment. 
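The offset computation above reduces to rounding each temporary buffer size up to the 256-byte CUDA allocation alignment and accumulating a running total. A small standalone sketch with made-up buffer sizes:

#include <cstdint>
#include <iostream>
#include <vector>

constexpr int64_t kCudaMallocAlignment = 256;

int64_t RoundUpToAlignment(int64_t value) {
  return (value + kCudaMallocAlignment - 1) / kCudaMallocAlignment *
         kCudaMallocAlignment;
}

int main() {
  // Hypothetical temp buffer sizes in bytes.
  const std::vector<int64_t> sizes = {100, 256, 1000};
  int64_t total = 0;
  for (int64_t size : sizes) {
    std::cout << "buffer of " << size << " bytes placed at offset " << total
              << "\n";
    total += RoundUpToAlignment(size);  // offsets 0, 256, 512
  }
  std::cout << "total (with padding): " << total << " bytes\n";  // 1536
  return 0;
}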
+ int64 total_size_of_temp_buffers_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TEMP_BUFFER_OFFSETS_H_ diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h new file mode 100644 index 0000000000..3ced348400 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/thunk.h @@ -0,0 +1,90 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_THUNK_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_THUNK_H_ + +#include +#include + +#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { + +class GpuExecutable; + +// Thunk acts as the bridge between IrEmitter and GpuExecutable. It stores the +// metadata IrEmitter generates for GpuExecutable to invoke an HloInstruction. +// +// Thunk provides the Initialize and ExecuteOnStream interface for GpuExecutable +// to initialize and execute the invocation respectively. Its subclasses are +// supposed to override these interfaces to launch a generated kernel or call an +// external library function (such as operations in cuBLAS). +// +// This is thread-compatible. +class Thunk { + public: + enum class Kind { + kConvolution, + kCopy, + kGemm, + kKernel, + kSequential, + kTuple, + kWhile, + }; + + // The hlo_instruction argument is meant to be the instruction this thunk was + // generated from, but Thunk never uses this argument other than to save it + // to Thunk::hlo_instruction, so it can be null. + explicit Thunk(Kind kind, const HloInstruction* hlo_instruction) + : kind_(kind), hlo_instruction_(hlo_instruction) {} + virtual ~Thunk() {} + Thunk(const Thunk&) = delete; + Thunk& operator=(const Thunk&) = delete; + + Kind kind() const { return kind_; } + const HloInstruction* hlo_instruction() const { return hlo_instruction_; } + + // Prepares for executing the thunk. This method is called only once over + // Thunk's lifetime. For example, KernelThunk::Initialize loads the PTX of a + // kernel, which is the same in every execution. + virtual tensorflow::Status Initialize(const GpuExecutable& executable) { + return tensorflow::Status::OK(); + } + + // Execute the kernel for the thunk on the given stream. This method must be + // called after Initialize and can be called multiple times over Thunk's + // lifetime. Stream argument must be non-null. + virtual tensorflow::Status ExecuteOnStream( + const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) = 0; + + private: + Kind kind_; + const HloInstruction* hlo_instruction_; +}; + +// A sequence of thunks. 
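The Initialize-once / execute-many contract of Thunk, and the way SequentialThunk earlier in this change folds a list of sub-thunks into a single thunk, can be illustrated with a stripped-down stand-in that uses bool in place of Status and no stream; the Mini* names are invented for the sketch.

#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// A minimal stand-in for the Thunk interface: Initialize() runs once,
// Execute() may run many times.
class MiniThunk {
 public:
  virtual ~MiniThunk() = default;
  virtual bool Initialize() { return true; }
  virtual bool Execute() = 0;
};

class PrintThunk : public MiniThunk {
 public:
  explicit PrintThunk(std::string msg) : msg_(std::move(msg)) {}
  bool Execute() override {
    std::cout << msg_ << "\n";
    return true;
  }

 private:
  std::string msg_;
};

// Runs sub-thunks in order, stopping at the first failure, mirroring how
// SequentialThunk forwards Initialize and ExecuteOnStream to its children.
class MiniSequentialThunk : public MiniThunk {
 public:
  explicit MiniSequentialThunk(std::vector<std::unique_ptr<MiniThunk>> thunks)
      : thunks_(std::move(thunks)) {}
  bool Initialize() override {
    for (auto& t : thunks_) {
      if (!t->Initialize()) return false;
    }
    return true;
  }
  bool Execute() override {
    for (auto& t : thunks_) {
      if (!t->Execute()) return false;
    }
    return true;
  }

 private:
  std::vector<std::unique_ptr<MiniThunk>> thunks_;
};

int main() {
  std::vector<std::unique_ptr<MiniThunk>> thunks;
  thunks.push_back(std::make_unique<PrintThunk>("kernel launch 1"));
  thunks.push_back(std::make_unique<PrintThunk>("library call"));
  MiniSequentialThunk seq(std::move(thunks));
  seq.Initialize();
  seq.Execute();  // prints both lines, in order
  return 0;
}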
+using ThunkSequence = std::vector>; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_THUNK_H_ diff --git a/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc b/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc new file mode 100644 index 0000000000..8addcd87ea --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc @@ -0,0 +1,163 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h" +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { +namespace gpu { + +void ThunkSchedule::AddDependenciesOnTransitiveOperands( + const Thunk& thunk, const HloInstruction& operand, + const std::unordered_map& hlo_to_thunk) { + if (hlo_to_thunk.count(&operand)) { + // If `operand` is mapped to a thunk, adds `operand` to `thunk`'s dependency + // list if `operand` is assigned to a different stream. As an optimization, + // we skip `operand`'s operands because `operand` depends on them already. + if (stream_assignment_->StreamNumberForHlo(operand) != + stream_assignment_->StreamNumberForHlo(*thunk.hlo_instruction())) { + depends_on_[&thunk].push_back(FindOrDie(hlo_to_thunk, &operand)); + } + } else { + // If `operand` doesn't need a thunk (e.g. bitcast), continue with its + // operands. + for (const auto* operand_of_operand : operand.operands()) { + AddDependenciesOnTransitiveOperands(thunk, *operand_of_operand, + hlo_to_thunk); + } + } +} + +ThunkSchedule::ThunkSchedule( + std::unique_ptr thunks, + std::unique_ptr stream_assignment, + const std::vector& hlo_total_order) + : thunks_(std::move(thunks)), + stream_assignment_(std::move(stream_assignment)) { + std::unordered_map hlo_to_thunk; + for (const auto& thunk : *thunks_) { + InsertOrDie(&hlo_to_thunk, thunk->hlo_instruction(), thunk.get()); + } + + for (const HloInstruction* hlo : hlo_total_order) { + if (hlo_to_thunk.count(hlo)) { + thunk_total_order_.push_back(FindOrDie(hlo_to_thunk, hlo)); + } + } + + for (const Thunk* thunk : thunk_total_order_) { + const auto* dst = thunk->hlo_instruction(); + CHECK(stream_assignment_->HasStreamAssigned(*dst)); + for (const auto* src : dst->operands()) { + AddDependenciesOnTransitiveOperands(*thunk, *src, hlo_to_thunk); + } + } + + RemoveRedundantDependencyEdges(); + + // Compute `depended_by_`, the inverse of `depends_on_`. 
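The dependency edges built by this constructor only matter across streams: a consumer on the same stream as its producer is already ordered after it by in-order stream execution. A minimal sketch of that filtering, with plain structs in place of HLOs and thunks (and without the transitive-operand walk used above for operands that have no thunk):

#include <iostream>
#include <string>
#include <utility>
#include <vector>

struct Node {
  std::string name;
  int stream;
  std::vector<const Node*> operands;
};

// Records a dependency edge only when producer and consumer are assigned to
// different streams; same-stream edges are implied by execution order.
std::vector<std::pair<std::string, std::string>> CrossStreamEdges(
    const std::vector<Node>& nodes) {
  std::vector<std::pair<std::string, std::string>> edges;
  for (const Node& node : nodes) {
    for (const Node* operand : node.operands) {
      if (operand->stream != node.stream) {
        edges.emplace_back(node.name, operand->name);  // node waits for operand
      }
    }
  }
  return edges;
}

int main() {
  Node dot1{"dot1", /*stream=*/0, {}};
  Node dot2{"dot2", /*stream=*/1, {}};
  Node add{"add", /*stream=*/0, {&dot1, &dot2}};
  for (const auto& e : CrossStreamEdges({dot1, dot2, add})) {
    std::cout << e.first << " waits for " << e.second << "\n";  // add waits for dot2
  }
  return 0;
}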
+  for (const auto& dependency : depends_on_) {
+    for (const auto* depended : dependency.second) {
+      depended_by_.insert(depended);
+    }
+  }
+}
+
+void ThunkSchedule::RemoveRedundantDependencyEdges() {
+  std::unordered_map<const Thunk*, int> thunk_to_total_order;
+  for (auto i = 0; i < thunk_total_order_.size(); ++i) {
+    InsertOrDie(&thunk_to_total_order, thunk_total_order_[i], i);
+  }
+
+  int stream_count = stream_assignment_->StreamCount();
+  //      S1  S2
+  //
+  //      T1<----+
+  //             |
+  //      T3<--+ |
+  //           | | depends on
+  //          T4 |
+  //             |
+  //          T2-+
+  //
+  // Suppose thunk T1 and T3 are scheduled on stream S1, and T2 and T4 are on
+  // stream S2. If T2 depends on T1 and T4 depends on T3, and
+  // order(T1) < order(T3) < order(T4) < order(T2) in the total order, then the
+  // dependency of T2 on T1 is redundant: T2 runs after T4 on stream S2, T4
+  // already waits for T3, and T3 runs after T1 on stream S1.
+  //
+  // To detect such redundancy, last_dependency(D, S) tracks, while walking the
+  // total order, the largest order number of a thunk on stream S that some
+  // earlier thunk on stream D depends on. Any further dependency from stream D
+  // on a stream-S thunk with an order number no larger than that is redundant
+  // and is removed.
+  Array2D<int> last_dependency(stream_count, stream_count, -1);
+  for (const Thunk* dst : thunk_total_order_) {
+    if (!depends_on_.count(dst)) {
+      continue;
+    }
+
+    int dst_stream =
+        stream_assignment_->StreamNumberForHlo(*dst->hlo_instruction());
+    std::list<const Thunk*>& sources = FindOrDie(depends_on_, dst);
+    for (auto iter = sources.begin(); iter != sources.end();) {
+      const Thunk* src = *iter;
+      // `dst` depends on `src`.
+      int src_stream =
+          stream_assignment_->StreamNumberForHlo(*src->hlo_instruction());
+      int src_order = FindOrDie(thunk_to_total_order, src);
+      if (src_order <= last_dependency(dst_stream, src_stream)) {
+        iter = sources.erase(iter);
+      } else {
+        last_dependency(dst_stream, src_stream) = src_order;
+        ++iter;
+      }
+    }
+    if (sources.empty()) {
+      depends_on_.erase(dst);
+    }
+  }
+}
+
+const std::list<const Thunk*>& ThunkSchedule::DependsOn(
+    const Thunk* thunk) const {
+  if (depends_on_.count(thunk)) {
+    return FindOrDie(depends_on_, thunk);
+  } else {
+    return empty_thunk_list_;
+  }
+}
+
+string ThunkSchedule::ToString() const {
+  string result = "Total order:\n";
+  for (Thunk* thunk : thunk_total_order_) {
+    tensorflow::strings::StrAppend(&result, "\t",
+                                   thunk->hlo_instruction()->ToString(), "\n");
+  }
+  tensorflow::strings::StrAppend(&result, "Dependencies:\n");
+  for (const auto& entry : depends_on_) {
+    const Thunk* dependent = entry.first;
+    for (const Thunk* dependency : entry.second) {
+      tensorflow::strings::StrAppend(
+          &result, "\t", dependent->hlo_instruction()->name(), " depends on ",
+          dependency->hlo_instruction()->name(), "\n");
+    }
+  }
+  return result;
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/thunk_schedule.h b/tensorflow/compiler/xla/service/gpu/thunk_schedule.h
new file mode 100644
index 0000000000..d3352994f8
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/thunk_schedule.h
@@ -0,0 +1,93 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_THUNK_SCHEDULE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_THUNK_SCHEDULE_H_ + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/thunk.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { +namespace gpu { + +// Encapsulates in which order and on which streams the thunks are executed. A +// schedule contains +// +// 1. A stream assignment indicating which stream each thunk is executed on. +// +// 2. A total order of all thunks. If A is ordered before B and they are +// assigned to the same stream, then A completes before B starts. If A is +// ordered before B and they are on different streams, their actual execution +// order is not determined. +// +// 3. A set of dependency edges. If A and B are scheduled on different streams +// and A has to complete before B starts (e.g. A produces an input of B), then B +// "depends" on A. +class ThunkSchedule { + public: + ThunkSchedule(std::unique_ptr thunks, + std::unique_ptr stream_assignment, + const std::vector& hlo_total_order); + + // Returns the total order of executing all the thunks. + const std::vector& TotalOrder() const { return thunk_total_order_; } + + // Thunks that `thunk` depends on. + const std::list& DependsOn(const Thunk* thunk) const; + // Whether `thunk` is depended by another thunk. + bool Depended(const Thunk* thunk) const { return depended_by_.count(thunk); } + + // Delegates to StreamAssignment. + int StreamCount() const { return stream_assignment_->StreamCount(); } + int StreamNumberForHlo(const HloInstruction& hlo) const { + return stream_assignment_->StreamNumberForHlo(hlo); + } + + string ToString() const; + + private: + void RemoveRedundantDependencyEdges(); + + // Adds `operand` and its transitive operands to the dependency list of + // `thunk`. + // + // Precondition: `operand` is a non-trivial (i.e. excluding + // thunk.hlo_instruction() itself) transitive operand of + // thunk.hlo_instruction(). + void AddDependenciesOnTransitiveOperands( + const Thunk& thunk, const HloInstruction& operand, + const std::unordered_map& hlo_to_thunk); + + std::unique_ptr thunks_; + std::vector thunk_total_order_; + + std::unordered_map> depends_on_; + std::set depended_by_; + std::list empty_thunk_list_; + + std::unique_ptr stream_assignment_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_THUNK_SCHEDULE_H_ diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc new file mode 100644 index 0000000000..323775b3e8 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc @@ -0,0 +1,49 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/tuple_thunk.h" + +#include "tensorflow/compiler/xla/util.h" + +namespace se = ::perftools::gputools; + +namespace xla { +namespace gpu { + +tensorflow::Status TupleThunk::ExecuteOnStream( + const BufferAllocations& buffer_allocations, se::Stream* stream) { + std::vector tuple_element_buffer_addresses; + for (BufferAllocation::Index tuple_element_buffer : tuple_element_buffers_) { + tuple_element_buffer_addresses.push_back( + buffer_allocations.GetDeviceAddress(tuple_element_buffer).opaque()); + } + se::DeviceMemory dest_buffer_address( + buffer_allocations.GetDeviceAddress(dest_buffer_)); + + auto host_size = tuple_element_buffer_addresses.size() * sizeof(void*); + if (!stream + ->ThenMemcpy(&dest_buffer_address, + tuple_element_buffer_addresses.data(), host_size) + .ok()) { + return InternalError( + "Unable to launch MemcpyH2D from %p to %p with size %lu", + tuple_element_buffer_addresses.data(), dest_buffer_address.opaque(), + sizeof(void*) * tuple_element_buffer_addresses.size()); + } + return tensorflow::Status::OK(); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.h b/tensorflow/compiler/xla/service/gpu/tuple_thunk.h new file mode 100644 index 0000000000..ca0404286f --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.h @@ -0,0 +1,60 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TUPLE_THUNK_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TUPLE_THUNK_H_ + +#include + +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/compiler/xla/service/gpu/thunk.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { + +// A thunk that copies the addresses of tuple elements to the buffer of the +// tuple. This avoids emitting kernels that may suffer from the parameter space +// issue (b/31336476). 
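What the thunk below does can be pictured with ordinary host memory: build an array holding the address of each tuple element and copy that array into the tuple's own buffer. In the sketch a plain std::memcpy stands in for the host-to-device stream->ThenMemcpy, and the buffers are stack arrays invented for the example.

#include <cstring>
#include <iostream>

int main() {
  // Stand-ins for the tuple's element buffers (in the real thunk these are
  // device allocations resolved through BufferAllocations).
  float element0[4] = {0};
  int element1[8] = {0};

  // The tuple buffer itself is just an array of pointers to its elements.
  void* host_side_table[2] = {element0, element1};
  void* tuple_buffer[2];  // destination buffer of the tuple

  // TupleThunk performs this copy with stream->ThenMemcpy (host to device);
  // here a plain memcpy stands in for it.
  std::memcpy(tuple_buffer, host_side_table, sizeof(host_side_table));

  std::cout << "tuple[0]=" << tuple_buffer[0]
            << " tuple[1]=" << tuple_buffer[1] << "\n";
  return 0;
}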
+class TupleThunk : public Thunk { + public: + TupleThunk(tensorflow::gtl::ArraySlice + tuple_element_buffers, + BufferAllocation::Index dest_buffer, + const HloInstruction* hlo_instruction) + : Thunk(Kind::kTuple, hlo_instruction), + tuple_element_buffers_(tuple_element_buffers.begin(), + tuple_element_buffers.end()), + dest_buffer_(dest_buffer) {} + + TupleThunk(const TupleThunk&) = delete; + TupleThunk& operator=(const TupleThunk&) = delete; + + tensorflow::Status ExecuteOnStream( + const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) override; + + private: + std::vector tuple_element_buffers_; + const BufferAllocation::Index dest_buffer_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TUPLE_THUNK_H_ diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.cc b/tensorflow/compiler/xla/service/gpu/while_thunk.cc new file mode 100644 index 0000000000..36883e4920 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/while_thunk.cc @@ -0,0 +1,74 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/while_thunk.h" + +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace xla { +namespace gpu { + +WhileThunk::WhileThunk(BufferAllocation::Index condition_result_buffer_index, + std::unique_ptr condition_thunk_sequence, + std::unique_ptr body_thunk_sequence, + const HloInstruction* hlo) + : Thunk(Kind::kWhile, hlo), + condition_result_buffer_index_(condition_result_buffer_index), + condition_thunk_sequence_(MakeUnique( + std::move(*condition_thunk_sequence), hlo)), + body_thunk_sequence_( + MakeUnique(std::move(*body_thunk_sequence), hlo)) {} + +tensorflow::Status WhileThunk::Initialize(const GpuExecutable& executable) { + TF_RETURN_IF_ERROR(condition_thunk_sequence_->Initialize(executable)); + TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(executable)); + return tensorflow::Status::OK(); +} + +tensorflow::Status WhileThunk::ExecuteOnStream( + const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) { + + perftools::gputools::DeviceMemoryBase condition_result_data = + buffer_allocations.GetDeviceAddress(condition_result_buffer_index_); + + while (true) { + // Invoke thunk sequence for while 'condition' computation. + TF_RETURN_IF_ERROR( + condition_thunk_sequence_->ExecuteOnStream(buffer_allocations, stream)); + + // Copy the result of condition computation and break the loop if 'false'. + bool condition_result; + stream->ThenMemcpy(&condition_result, condition_result_data, sizeof(bool)); + if (!stream->BlockHostUntilDone()) { + return InternalError( + "Failed to complete all kernels launched on stream %p", stream); + } + + if (!condition_result) { + break; + } + + // Invoke thunk sequence for while 'body' computation. 
+ TF_RETURN_IF_ERROR( + body_thunk_sequence_->ExecuteOnStream(buffer_allocations, stream)); + } + return tensorflow::Status::OK(); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.h b/tensorflow/compiler/xla/service/gpu/while_thunk.h new file mode 100644 index 0000000000..1658cdaf87 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/while_thunk.h @@ -0,0 +1,62 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_WHILE_THUNK_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_WHILE_THUNK_H_ + +#include + +#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/thunk.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { + +// WhileThunk implements the while instruction on GPU by invoking a thunk +// sequence for the while 'condition' computation, and (conditionally) another +// thunk sequence for the while 'body' computation. WhileThunk assumes that +// buffers for the following set of while-related instructions share the same +// allocation: +// init, condition.parameter, body.parameter, body.root, while.result +// WhileThunk synchronizes the stream to test the result of the 'condition' +// computation. +class WhileThunk : public Thunk { + public: + // Constructs a WhileThunk to compute while instruction 'hlo'. + WhileThunk(BufferAllocation::Index condition_result_buffer_index, + std::unique_ptr condition_thunk_sequence, + std::unique_ptr body_thunk_sequence, + const HloInstruction* hlo); + WhileThunk(const WhileThunk&) = delete; + WhileThunk& operator=(const WhileThunk&) = delete; + + tensorflow::Status Initialize(const GpuExecutable& executable) override; + tensorflow::Status ExecuteOnStream( + const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) override; + + private: + BufferAllocation::Index condition_result_buffer_index_; + std::unique_ptr condition_thunk_sequence_; + std::unique_ptr body_thunk_sequence_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_WHILE_THUNK_H_ diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer.cc b/tensorflow/compiler/xla/service/gpu/while_transformer.cc new file mode 100644 index 0000000000..7ebc929ced --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/while_transformer.cc @@ -0,0 +1,532 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
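The control flow of WhileThunk::ExecuteOnStream above, reduced to plain callables: run the condition, read its boolean result back, and keep running the body until the condition is false. The Status handling, the device-to-host copy, and the stream synchronization are elided; RunWhile is a name invented for the sketch.

#include <functional>
#include <iostream>

// Mirrors WhileThunk's loop with plain callables: the condition writes its
// result into a buffer, the caller reads it back and decides whether to run
// the body again.
void RunWhile(const std::function<void(bool*)>& condition,
              const std::function<void()>& body) {
  bool condition_result = true;
  while (true) {
    condition(&condition_result);  // condition thunk writes its result buffer
    // (WhileThunk additionally blocks until the stream is done before reading.)
    if (!condition_result) break;
    body();                        // body thunk sequence
  }
}

int main() {
  int i = 0;
  RunWhile(/*condition=*/[&](bool* result) { *result = (i < 3); },
           /*body=*/[&] { std::cout << "body iteration " << i++ << "\n"; });
  return 0;
}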
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/while_transformer.h" + +#include + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace xla { +namespace gpu { + +namespace { + +// MatcherBase is a base class that provides common functionality for +// sub-classes which match specific target sub-computations (i.e. loop +// induction variable initialization, comparison and update). +// TODO(b/33483676) Use an expression tree to specify computations to pattern +// match for while transformations. +class MatcherBase { + public: + enum State { + ADD, + PRED, + CONST, + COPY, + GTE0, + GTE1, + PARAM, + TUPLE0, + TUPLE1, + WHILE + }; + + // Initializes MatcherBase with 'computation' and initial state 'state'. + explicit MatcherBase(const HloComputation* computation, State state) + : computation_(computation), state_(state) {} + + // Initializes MatcherBase with 'computation', initial state 'state', and + // value for 'tuple_index'. + MatcherBase(const HloComputation* computation, State state, + const int64 tuple_index) + : computation_(computation), state_(state), tuple_index_(tuple_index) {} + virtual ~MatcherBase() {} + + // Overridden by sub-classes to match specific target sub-computations. + // Returns OK if target sub-computation was matched, error status otherwise. + virtual tensorflow::Status Run() = 0; + + // Matches a Constant instruction of integral type, parses its value, and + // stores the value in 'const_value_'. + // Returns OK on success, error status otherwise. + tensorflow::Status MatchConst() { + const HloInstruction* instruction = stack_.back(); + stack_.pop_back(); + if (instruction->opcode() != HloOpcode::kConstant) { + return InvalidArgument("Expected constant instruction."); + } + if (!IsSupportedIntType(instruction->shape())) { + return InvalidArgument("Expected constant of integral type."); + } + const Literal& literal = instruction->literal(); + PrimitiveType type = literal.shape().element_type(); + if (type == S32) { + const_value_ = + static_cast(LiteralUtil::GetFirstElement(literal)); + } else if (type == S64) { + const_value_ = LiteralUtil::GetFirstElement(literal); + } else { + return InvalidArgument("Must use S32 or S64 integral types."); + } + return tensorflow::Status::OK(); + } + + // Matches a Copy instruction. + // Pushes its operand on the stack for subsequent processing. + // Returns OK on success, error status otherwise. 
+  tensorflow::Status MatchCopy() {
+    const HloInstruction* instruction = stack_.back();
+    stack_.pop_back();
+    if (instruction->opcode() != HloOpcode::kCopy) {
+      return InvalidArgument("Expected Copy instruction.");
+    }
+    stack_.push_back(instruction->operand(0));
+    return tensorflow::Status::OK();
+  }
+
+  // Matches a GetTupleElement instruction and either parses its 'tuple_index'
+  // parameter (if not initialized) or compares its 'tuple_index' with the
+  // previously initialized value.
+  // Pushes its operand on the stack for subsequent processing.
+  // Returns OK on success, error status otherwise.
+  tensorflow::Status MatchGetTupleElement() {
+    const HloInstruction* instruction = stack_.back();
+    stack_.pop_back();
+    if (instruction->opcode() != HloOpcode::kGetTupleElement) {
+      return InvalidArgument("Expected GetTupleElement instruction.");
+    }
+    if (!IsSupportedIntType(instruction->shape())) {
+      return InvalidArgument(
+          "GetTupleElement instruction must be of integral type.");
+    }
+    if (tuple_index_ == -1) {
+      tuple_index_ = instruction->tuple_index();
+    } else if (tuple_index_ != instruction->tuple_index()) {
+      return InvalidArgument("Invalid tuple index");
+    }
+    stack_.push_back(instruction->operand(0));
+    return tensorflow::Status::OK();
+  }
+
+  // Matches a Parameter instruction and compares it with 'computation_'
+  // parameter instruction at index 0.
+  // Returns OK on success, error status otherwise.
+  tensorflow::Status MatchParameter() {
+    const HloInstruction* instruction = stack_.back();
+    stack_.pop_back();
+    if (instruction != computation_->parameter_instruction(0)) {
+      return InvalidArgument("Expected Parameter instruction.");
+    }
+    return tensorflow::Status::OK();
+  }
+
+  // Matches a Tuple instruction.
+  // Pushes operand at 'tuple_index_' on the stack for subsequent processing.
+  // Returns OK on success, error status otherwise.
+  tensorflow::Status MatchTuple() {
+    const HloInstruction* instruction = stack_.back();
+    stack_.pop_back();
+    if (instruction->opcode() != HloOpcode::kTuple) {
+      return InvalidArgument("Expected Tuple instruction.");
+    }
+    stack_.push_back(instruction->operand(tuple_index_));
+    return tensorflow::Status::OK();
+  }
+
+ protected:
+  const HloComputation* computation_;
+  State state_;
+  int64 tuple_index_ = -1;
+  int64 const_value_ = -1;
+  std::vector<const HloInstruction*> stack_;
+
+ private:
+  bool IsSupportedIntType(const Shape& shape) {
+    return shape.element_type() == S32 || shape.element_type() == S64;
+  }
+
+  TF_DISALLOW_COPY_AND_ASSIGN(MatcherBase);
+};
+
+// WhileConditionComputationMatcher matches one of the following two
+// target While condition computations:
+//
+// Case 1: LessThan
+//
+//          PARAM
+//            |
+//            |
+//          GTE0   CONST
+//             \    /
+//              \  /
+//              PRED
+//
+//
+// Case 2: GreaterThan
+//
+//          PARAM
+//            |
+//            |
+//   CONST  GTE0
+//      \    /
+//       \  /
+//       PRED
+//
+// If we do not successfully match one of the two target cases, we return a
+// descriptive error status.
+//
+// If we do successfully match one of the cases, we parse and store the
+// following two pieces of information from the computation:
+//
+// *) 'tuple_index':
+//    *) The loop induction variable tuple_index from the GetTupleElement
+//       instruction of the matched computation.
+//    *) Used in subsequent matching passes of while init operand and body
+//       computations to select the loop induction variable tuple element.
+//
+// *) 'loop_limit':
+//    *) The integral value from the Constant root operand in the matched
+//       computation.
+//    *) Used as the constant for the loop limit.
+// +class WhileConditionComputationMatcher : public MatcherBase { + public: + WhileConditionComputationMatcher(const HloComputation* computation) + : MatcherBase(computation, PRED) { + stack_.push_back(computation_->root_instruction()); + } + + // Loop attempting to match target computation. + tensorflow::Status Run() { + while (!stack_.empty()) { + switch (state_) { + case PRED: { + TF_RETURN_IF_ERROR(MatchPred()); + break; + } + case CONST: { + TF_RETURN_IF_ERROR(MatchConst()); + state_ = GTE0; + break; + } + case GTE0: { + TF_RETURN_IF_ERROR(MatchGetTupleElement()); + state_ = PARAM; + break; + } + case PARAM: { + TF_RETURN_IF_ERROR(MatchParameter()); + break; + } + default: + return InvalidArgument("Unexpected state."); + } + } + return tensorflow::Status::OK(); + } + + int64 loop_limit() const { return const_value_; } + int64 tuple_index() const { return tuple_index_; } + + private: + tensorflow::Status MatchPred() { + const HloInstruction* instruction = stack_.back(); + stack_.pop_back(); + // Push operands in canonical order: GetTupleElement, Constant. + if (instruction->opcode() == HloOpcode::kLt) { + stack_.push_back(instruction->operand(0)); + stack_.push_back(instruction->operand(1)); + } else if (instruction->opcode() == HloOpcode::kGt) { + stack_.push_back(instruction->operand(1)); + stack_.push_back(instruction->operand(0)); + } else { + return InvalidArgument("Condition must be LT or GT."); + } + state_ = CONST; + return tensorflow::Status::OK(); + } + + TF_DISALLOW_COPY_AND_ASSIGN(WhileConditionComputationMatcher); +}; + +// WhileInitOperandMatcher matches one of the following two target while +// init operand sub-computations: +// +// Case 1: No copy. +// +// CONST // Tuple.operand(tuple_index) +// | +// TUPLE0 // While.operand(0) +// | +// WHILE +// +// Case 2: With copy. +// +// CONST // Tuple1.operand(tuple_index) +// | +// TUPLE1 // GetTupleElement.operand(0) +// | +// GTE0 // Copy.operand(0) +// | +// COPY // Tuple0.operand(tuple_index) +// | +// TUPLE0 // While.operand(0) +// | +// While +// +class WhileInitOperandMatcher : public MatcherBase { + public: + WhileInitOperandMatcher(const HloInstruction* while_hlo, + const int64 tuple_index) + : MatcherBase(while_hlo->parent(), WHILE, tuple_index) { + stack_.push_back(while_hlo); + } + + // Loop attempting to match target computation. + tensorflow::Status Run() { + while (!stack_.empty()) { + switch (state_) { + case WHILE: { + TF_RETURN_IF_ERROR(MatchWhile()); + break; + } + case TUPLE0: { + TF_RETURN_IF_ERROR(MatchTuple()); + TF_RETURN_IF_ERROR(PostMatchTuple()); + break; + } + case TUPLE1: { + TF_RETURN_IF_ERROR(MatchTuple()); + state_ = CONST; + break; + } + case CONST: { + TF_RETURN_IF_ERROR(MatchConst()); + break; + } + case COPY: { + TF_RETURN_IF_ERROR(MatchCopy()); + state_ = GTE0; + break; + } + case GTE0: { + TF_RETURN_IF_ERROR(MatchGetTupleElement()); + state_ = TUPLE1; + break; + } + default: + return InvalidArgument("Unexpected state."); + } + } + return tensorflow::Status::OK(); + } + + int64 loop_start() const { return const_value_; } + + private: + tensorflow::Status MatchWhile() { + const HloInstruction* instruction = stack_.back(); + stack_.pop_back(); + if (instruction->opcode() != HloOpcode::kWhile) { + return InvalidArgument("While init match expected while instruction."); + } + // Push while 'init' operand. 
+ stack_.push_back(instruction->operand(0)); + state_ = TUPLE0; + return tensorflow::Status::OK(); + } + + tensorflow::Status PostMatchTuple() { + // Transition to the next state based on matched tuple operand. + const HloInstruction* operand = stack_.back(); + if (operand->opcode() == HloOpcode::kConstant) { + state_ = CONST; + } else if (operand->opcode() == HloOpcode::kCopy) { + state_ = COPY; + } else { + return InvalidArgument("Expected constant or copy tuple operand."); + } + return tensorflow::Status::OK(); + } + + TF_DISALLOW_COPY_AND_ASSIGN(WhileInitOperandMatcher); +}; + +// WhileBodyComputationMatcher matches one of the following two target +// sub-computations: +// +// Case 1: +// +// PARAM +// | +// CONST GTE1 +// \ / +// ADD // Tuple.operand(tuple_index). +// | +// TUPLE0 (root) +// +// Case 2: +// +// PARAM +// | +// CONST GTE1 +// \ / +// ADD // Tuple.operand(tuple_index). +// | +// TUPLE1 +// | +// GTE0 +// | +// COPY +// | +// TUPLE0 (root) +// +// Note that the induction variable tuple element can have multiple users +// in the while loop body computation, but one update path. +// Matching proceeds from the root[tuple_index] to param[tuple_index]. +// +class WhileBodyComputationMatcher : public MatcherBase { + public: + WhileBodyComputationMatcher(const HloComputation* computation, + const int64 tuple_index) + : MatcherBase(computation, TUPLE0, tuple_index) { + stack_.push_back(computation_->root_instruction()); + } + + // Loop attempting to match target computation. + tensorflow::Status Run() { + while (!stack_.empty()) { + switch (state_) { + case TUPLE0: { + TF_RETURN_IF_ERROR(MatchTuple()); + TF_RETURN_IF_ERROR(PostMatchTuple()); + break; + } + case TUPLE1: { + TF_RETURN_IF_ERROR(MatchTuple()); + state_ = ADD; + break; + } + case ADD: { + TF_RETURN_IF_ERROR(MatchAdd()); + break; + } + case CONST: { + TF_RETURN_IF_ERROR(MatchConst()); + state_ = GTE1; + break; + } + case COPY: { + TF_RETURN_IF_ERROR(MatchCopy()); + state_ = GTE0; + break; + } + case GTE0: { + TF_RETURN_IF_ERROR(MatchGetTupleElement()); + state_ = TUPLE1; + break; + } + case GTE1: { + TF_RETURN_IF_ERROR(MatchGetTupleElement()); + state_ = PARAM; + break; + } + case PARAM: { + TF_RETURN_IF_ERROR(MatchParameter()); + break; + } + default: + return InvalidArgument("Unexpected state."); + } + } + return tensorflow::Status::OK(); + } + + int64 loop_increment() const { return const_value_; } + + private: + tensorflow::Status MatchAdd() { + const HloInstruction* instruction = stack_.back(); + stack_.pop_back(); + if (instruction->opcode() != HloOpcode::kAdd) { + return InvalidArgument("Expected Add induction variable update."); + } + // Push in canonical order: GetTupleElement, Constant. + if (instruction->operand(0)->opcode() == HloOpcode::kConstant && + instruction->operand(1)->opcode() == HloOpcode::kGetTupleElement) { + stack_.push_back(instruction->operand(1)); + stack_.push_back(instruction->operand(0)); + } else if (instruction->operand(1)->opcode() == HloOpcode::kConstant && + instruction->operand(0)->opcode() == + HloOpcode::kGetTupleElement) { + stack_.push_back(instruction->operand(0)); + stack_.push_back(instruction->operand(1)); + } else { + return InvalidArgument("Invalid types for Add operands"); + } + state_ = CONST; + return tensorflow::Status::OK(); + } + + tensorflow::Status PostMatchTuple() { + // Transition to the next state based on matched tuple operand. 
+ const HloInstruction* operand = stack_.back(); + if (operand->opcode() == HloOpcode::kAdd) { + state_ = ADD; + } else if (operand->opcode() == HloOpcode::kCopy) { + state_ = COPY; + } else { + return InvalidArgument("Expected add or copy tuple operand."); + } + return tensorflow::Status::OK(); + } + + TF_DISALLOW_COPY_AND_ASSIGN(WhileBodyComputationMatcher); +}; + +} // namespace + +StatusOr> CanTransformWhileToFor( + const HloInstruction* while_hlo) { + if (while_hlo->opcode() != HloOpcode::kWhile) { + return InvalidArgument("Expected While instruction."); + } + + WhileConditionComputationMatcher cond_matcher(while_hlo->while_condition()); + TF_RETURN_IF_ERROR(cond_matcher.Run()); + + WhileInitOperandMatcher init_matcher(while_hlo, cond_matcher.tuple_index()); + TF_RETURN_IF_ERROR(init_matcher.Run()); + + WhileBodyComputationMatcher body_matcher(while_hlo->while_body(), + cond_matcher.tuple_index()); + TF_RETURN_IF_ERROR(body_matcher.Run()); + + // Check for valid For loop parameters. + if (init_matcher.loop_start() >= cond_matcher.loop_limit()) { + return InvalidArgument("Loop start must be less than loop limit."); + } + if (body_matcher.loop_increment() <= 0) { + return InvalidArgument("Loop increment must greater than zero."); + } + return std::make_tuple(init_matcher.loop_start(), cond_matcher.loop_limit(), + body_matcher.loop_increment()); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer.h b/tensorflow/compiler/xla/service/gpu/while_transformer.h new file mode 100644 index 0000000000..a4f527fce0 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/while_transformer.h @@ -0,0 +1,43 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_WHILE_TRANSFORMER_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_WHILE_TRANSFORMER_H_ + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { +namespace gpu { + +// Runs an analysis of the while loop instruction 'while_hlo' (and its +// associated sub-computations) to determine if it can be transformed into an +// equivalent "for" loop with the following "for" loop parameters: +// +// *) 'loop_start': loop induction variable starting value. +// *) 'loop_limit': loop induction variable limit value. +// *) 'loop_increment': loop induction variable per-iteration increment value. +// +// Returns an std::tuple = (loop_start, loop_limit, loop_increment) on success. +// The values in the returned tuple are values extracted from the 'while_hlo' +// operand (and its sub-computations) during analysis. +// Returns an error status on failure. 
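+//
+// Example use (a minimal sketch, assuming 'while_hlo' is a kWhile
+// instruction; the unit test for this pass exercises the same pattern):
+//
+//   auto result = CanTransformWhileToFor(while_hlo);
+//   if (result.ok()) {
+//     int64 loop_start, loop_limit, loop_increment;
+//     std::tie(loop_start, loop_limit, loop_increment) =
+//         result.ConsumeValueOrDie();
+//     // The while loop can now be emitted as a counted loop running from
+//     // loop_start up to loop_limit in steps of loop_increment.
+//   }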
+StatusOr> CanTransformWhileToFor( + const HloInstruction* while_hlo); + +} // namespace gpu +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_WHILE_TRANSFORMER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc new file mode 100644 index 0000000000..799b30d21e --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc @@ -0,0 +1,218 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/while_transformer.h" + +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" + +namespace xla { +namespace { + +class WhileTransformerTest : public HloTestBase { + protected: + WhileTransformerTest() + : module_(TestName()), + induction_variable_shape_(ShapeUtil::MakeShape(S32, {})), + data_shape_(ShapeUtil::MakeShape(F32, {8})), + loop_state_shape_(ShapeUtil::MakeTupleShape( + {induction_variable_shape_, data_shape_})), + condition_result_shape_(ShapeUtil::MakeShape(PRED, {})) {} + + std::unique_ptr BuildConditionComputation( + const int64 tuple_index, const int64 limit) { + auto builder = HloComputation::Builder(TestName() + ".Condition"); + auto limit_const = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(limit))); + auto loop_state = builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state")); + auto induction_variable = + builder.AddInstruction(HloInstruction::CreateGetTupleElement( + limit_const->shape(), loop_state, tuple_index)); + builder.AddInstruction( + HloInstruction::CreateBinary(condition_result_shape_, HloOpcode::kLt, + induction_variable, limit_const)); + return builder.Build(); + } + + std::unique_ptr BuildBodyComputation( + const int64 ind_var_tuple_index, const int64 data_tuple_index, + const int64 increment, bool insert_copies = false) { + auto builder = HloComputation::Builder(TestName() + ".Body"); + // Create param instruction to access loop state. + auto loop_state = builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state")); + // Update the induction variable GTE(ind_var_tuple_index). + auto induction_variable = + builder.AddInstruction(HloInstruction::CreateGetTupleElement( + induction_variable_shape_, loop_state, ind_var_tuple_index)); + auto inc = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(increment))); + auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( + induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc)); + // Update data GTE(data_tuple_index). + auto data = builder.AddInstruction(HloInstruction::CreateGetTupleElement( + data_shape_, loop_state, data_tuple_index)); + // Use 'induction_variable' in computation with no path to output tuple. 
+ auto update = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape_, induction_variable, {8})); + auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( + data_shape_, HloOpcode::kAdd, data, update)); + // Create output Tuple. + auto tuple0 = + ind_var_tuple_index == 0 + ? builder.AddInstruction(HloInstruction::CreateTuple({add0, add1})) + : builder.AddInstruction(HloInstruction::CreateTuple({add1, add0})); + if (insert_copies) { + InsertTupleElementCopies(ind_var_tuple_index, tuple0, &builder); + } + return builder.Build(); + } + + HloInstruction* BuildWhileInstruction(HloComputation* condition, + HloComputation* body, + const int64 ind_var_tuple_index, + const int64 ind_var_init, + bool insert_copies = false) { + auto builder = HloComputation::Builder(TestName() + ".While"); + auto induction_var_init = + builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(ind_var_init))); + auto data_init = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1( + {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}))); + auto loop_state_init = + ind_var_tuple_index == 0 + ? builder.AddInstruction( + HloInstruction::CreateTuple({induction_var_init, data_init})) + : builder.AddInstruction( + HloInstruction::CreateTuple({data_init, induction_var_init})); + if (insert_copies) { + loop_state_init = InsertTupleElementCopies(ind_var_tuple_index, + loop_state_init, &builder); + } + auto while_hlo = builder.AddInstruction(HloInstruction::CreateWhile( + loop_state_shape_, condition, body, loop_state_init)); + module_.AddEntryComputation(builder.Build()); + return while_hlo; + } + + HloInstruction* InsertTupleElementCopies(const int64 ind_var_tuple_index, + HloInstruction* tuple0, + HloComputation::Builder* builder) { + auto gte0 = builder->AddInstruction(HloInstruction::CreateGetTupleElement( + induction_variable_shape_, tuple0, ind_var_tuple_index)); + const int64 gte1_tuple_index = ind_var_tuple_index == 0 ? 1 : 0; + auto gte1 = builder->AddInstruction(HloInstruction::CreateGetTupleElement( + data_shape_, tuple0, gte1_tuple_index)); + // Insert copies. + auto copy0 = builder->AddInstruction( + HloInstruction::CreateUnary(gte0->shape(), HloOpcode::kCopy, gte0)); + auto copy1 = builder->AddInstruction( + HloInstruction::CreateUnary(gte1->shape(), HloOpcode::kCopy, gte1)); + + return ind_var_tuple_index == 0 + ? 
builder->AddInstruction( + HloInstruction::CreateTuple({copy0, copy1})) + : builder->AddInstruction( + HloInstruction::CreateTuple({copy1, copy0})); + } + + HloModule module_; + Shape induction_variable_shape_; + Shape data_shape_; + Shape loop_state_shape_; + Shape condition_result_shape_; +}; + +TEST_F(WhileTransformerTest, InductionVariableAtTupleElement0) { + auto condition = + module_.AddEmbeddedComputation(BuildConditionComputation(0, 10)); + auto body = module_.AddEmbeddedComputation(BuildBodyComputation(0, 1, 1)); + auto result = + gpu::CanTransformWhileToFor(BuildWhileInstruction(condition, body, 0, 0)); + EXPECT_TRUE(result.ok()); + auto tuple = result.ConsumeValueOrDie(); + EXPECT_EQ(0, std::get<0>(tuple)); + EXPECT_EQ(10, std::get<1>(tuple)); + EXPECT_EQ(1, std::get<2>(tuple)); +} + +TEST_F(WhileTransformerTest, InductionVariableAtTupleElement0_WithBodyCopies) { + auto condition = + module_.AddEmbeddedComputation(BuildConditionComputation(0, 10)); + auto body = + module_.AddEmbeddedComputation(BuildBodyComputation(0, 1, 1, true)); + auto result = + gpu::CanTransformWhileToFor(BuildWhileInstruction(condition, body, 0, 0)); + EXPECT_TRUE(result.ok()); + auto tuple = result.ConsumeValueOrDie(); + EXPECT_EQ(0, std::get<0>(tuple)); + EXPECT_EQ(10, std::get<1>(tuple)); + EXPECT_EQ(1, std::get<2>(tuple)); +} + +TEST_F(WhileTransformerTest, InductionVariableAtTupleElement0_WithInitCopies) { + auto condition = + module_.AddEmbeddedComputation(BuildConditionComputation(0, 10)); + auto body = module_.AddEmbeddedComputation(BuildBodyComputation(0, 1, 1)); + auto result = gpu::CanTransformWhileToFor( + BuildWhileInstruction(condition, body, 0, 0, true)); + EXPECT_TRUE(result.ok()); + auto tuple = result.ConsumeValueOrDie(); + EXPECT_EQ(0, std::get<0>(tuple)); + EXPECT_EQ(10, std::get<1>(tuple)); + EXPECT_EQ(1, std::get<2>(tuple)); +} + +TEST_F(WhileTransformerTest, InductionVariableAtTupleElement1) { + auto condition = + module_.AddEmbeddedComputation(BuildConditionComputation(1, 10)); + auto body = module_.AddEmbeddedComputation(BuildBodyComputation(1, 0, 1)); + auto result = + gpu::CanTransformWhileToFor(BuildWhileInstruction(condition, body, 1, 0)); + EXPECT_TRUE(result.ok()); + auto tuple = result.ConsumeValueOrDie(); + EXPECT_EQ(0, std::get<0>(tuple)); + EXPECT_EQ(10, std::get<1>(tuple)); + EXPECT_EQ(1, std::get<2>(tuple)); +} + +TEST_F(WhileTransformerTest, InvalidLoopLimit) { + auto condition = + module_.AddEmbeddedComputation(BuildConditionComputation(0, 5)); + auto body = module_.AddEmbeddedComputation(BuildBodyComputation(0, 1, 1)); + auto result = gpu::CanTransformWhileToFor( + BuildWhileInstruction(condition, body, 0, 10)); + EXPECT_FALSE(result.ok()); + EXPECT_MATCH( + result.status().error_message(), + testing::ContainsRegex("Loop start must be less than loop limit.")); +} + +TEST_F(WhileTransformerTest, InvalidLoopIncrement) { + auto condition = + module_.AddEmbeddedComputation(BuildConditionComputation(0, 10)); + auto body = module_.AddEmbeddedComputation(BuildBodyComputation(0, 1, -1)); + auto result = + gpu::CanTransformWhileToFor(BuildWhileInstruction(condition, body, 0, 0)); + EXPECT_FALSE(result.ok()); + EXPECT_MATCH( + result.status().error_message(), + testing::ContainsRegex("Loop increment must greater than zero.")); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/graphviz_example.cc b/tensorflow/compiler/xla/service/graphviz_example.cc new file mode 100644 index 0000000000..cd00a41a03 --- /dev/null +++ 
b/tensorflow/compiler/xla/service/graphviz_example.cc @@ -0,0 +1,165 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Example HLO graph which demonstrates Graphviz dumper for HLO +// computations. When run, pushes the example DOT graph to the Graphviz service +// and prints the URL. Useful for seeing effect of changes to the graph +// generation code. + +#include +#include +#include + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +// Adds a computation to the given HLO module which adds a scalar constant to +// its parameter and returns the result. +HloComputation* AddScalarConstantComputation(int64 addend, HloModule* module) { + auto builder = + HloComputation::Builder(tensorflow::strings::StrCat("add_", addend)); + auto x_value = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "x_value")); + auto half = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.5))); + builder.AddInstruction(HloInstruction::CreateBinary( + half->shape(), HloOpcode::kAdd, x_value, half)); + return module->AddEmbeddedComputation(builder.Build()); +} + +// Adds a computation to the given HLO module which sums its two parameters and +// returns the result. +HloComputation* ScalarSumComputation(HloModule* module) { + auto builder = HloComputation::Builder("add"); + auto lhs = builder.AddInstruction( + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "lhs")); + auto rhs = builder.AddInstruction( + HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {}), "rhs")); + builder.AddInstruction( + HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs)); + return module->AddEmbeddedComputation(builder.Build()); +} + +// Adds a computation to the given HLO module which forwards its argument to a +// kCall instruction which then calls the given computation. 
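+// For example, MakeBigGraph below consumes it roughly as follows (a sketch;
+// 'result_shape' and 'some_operand' stand in for the shape and operand used
+// at the actual call site):
+//
+//   auto add_one = AddScalarConstantComputation(1.0, module.get());
+//   auto call_computation = CallForwardingComputation(add_one, module.get());
+//   builder.AddInstruction(HloInstruction::CreateCall(
+//       result_shape, {some_operand}, call_computation));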
+HloComputation* CallForwardingComputation(HloComputation* computation, + HloModule* module) { + auto builder = HloComputation::Builder("call_forward"); + auto arg = builder.AddInstruction( + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "arg")); + builder.AddInstruction( + HloInstruction::CreateCall(arg->shape(), {arg}, computation)); + return module->AddEmbeddedComputation(builder.Build()); +} + +// Create a large, arbitrary computation with many different kinds of +// instructions. Sets the computation as the entry to an HLO module and returns +// the module. +std::unique_ptr MakeBigGraph() { + auto module = MakeUnique("BigGraph"); + + auto builder = HloComputation::Builder("TestBigGraphvizGraph"); + + // Shapes used in the computation. + auto mshape = ShapeUtil::MakeShape(F32, {3, 5}); + auto vshape = ShapeUtil::MakeShape(F32, {3}); + auto sshape = ShapeUtil::MakeShape(F32, {3}); + + // Create a set of parameter instructions. + auto param_v0 = + builder.AddInstruction(HloInstruction::CreateParameter(0, vshape, "foo")); + auto param_v1 = + builder.AddInstruction(HloInstruction::CreateParameter(1, vshape, "bar")); + auto param_v2 = + builder.AddInstruction(HloInstruction::CreateParameter(2, vshape, "baz")); + auto param_s = + builder.AddInstruction(HloInstruction::CreateParameter(3, sshape, "qux")); + auto param_m = + builder.AddInstruction(HloInstruction::CreateParameter(4, mshape, "zzz")); + + // Add an arbitrary expression of different instructions. + auto copy = builder.AddInstruction( + HloInstruction::CreateUnary(vshape, HloOpcode::kCopy, param_v0)); + auto clamp = builder.AddInstruction(HloInstruction::CreateTernary( + vshape, HloOpcode::kClamp, copy, param_v1, param_v2)); + auto dot = builder.AddInstruction( + HloInstruction::CreateBinary(vshape, HloOpcode::kDot, clamp, param_v0)); + auto tuple = builder.AddInstruction( + HloInstruction::CreateTuple({dot, param_s, clamp})); + auto scalar = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(sshape, tuple, 2)); + auto add_one = AddScalarConstantComputation(1.0, module.get()); + auto rng = builder.AddInstruction( + HloInstruction::CreateRng(vshape, RNG_UNIFORM, {param_m, param_m})); + auto one = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + auto add_computation = ScalarSumComputation(module.get()); + builder.AddInstruction( + HloInstruction::CreateReduce(vshape, rng, one, {1}, add_computation)); + auto map1 = builder.AddInstruction( + HloInstruction::CreateMap(sshape, {scalar}, add_one)); + auto map2 = builder.AddInstruction( + HloInstruction::CreateMap(sshape, {map1}, add_one)); + auto map3 = builder.AddInstruction( + HloInstruction::CreateMap(sshape, {map2}, add_one)); + + // Create a fusion instruction containing the chain of map instructions. + auto fusion = builder.AddInstruction(HloInstruction::CreateFusion( + sshape, HloInstruction::FusionKind::kLoop, map3)); + fusion->FuseInstruction(map2); + fusion->FuseInstruction(map1); + + // Add a random trace instruction. + builder.AddInstruction(HloInstruction::CreateTrace("trace", dot)); + + // Add a call instruction will calls the call-forwarding computation to call + // another computation. 
+ auto call_computation = CallForwardingComputation(add_one, module.get()); + builder.AddInstruction( + HloInstruction::CreateCall(fusion->shape(), {fusion}, call_computation)); + + module->AddEntryComputation(builder.Build()); + return module; +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + tensorflow::port::InitMain(argv[0], &argc, &argv); + + auto module = xla::MakeBigGraph(); + + printf("Graph URL: %s\n", + xla::hlo_graph_dumper::DumpGraph( + *module->entry_computation(), "Example computation", + /*show_addresses=*/false, /*show_layouts=*/false) + .c_str()); + return 0; +} diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc new file mode 100644 index 0000000000..4b7d795cc6 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -0,0 +1,520 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_computation.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +std::unique_ptr HloComputation::Builder::Build( + HloInstruction* root_instruction) { + int parameter_count = 0; + for (auto& instruction : instructions_) { + if (instruction->opcode() == HloOpcode::kParameter) { + parameter_count++; + } + } + // If root_instruction is not specified use the last added instruction. + HloInstruction* root = + root_instruction ? 
root_instruction : last_added_instruction_; + CHECK_NE(nullptr, root); + + return WrapUnique( + new HloComputation(name_, parameter_count, &instructions_, root)); +} + +HloComputation::HloComputation( + const string& name, int parameter_count, + std::vector>* instructions, + HloInstruction* root_instruction) + : name_(name), + root_instruction_(root_instruction), + instruction_name_uniquer_(/*separator=*/".") { + param_instructions_.resize(parameter_count, nullptr); + bool root_found = false; + for (auto& instruction : *instructions) { + if (instruction->opcode() == HloOpcode::kParameter) { + int64 param_no = instruction->parameter_number(); + CHECK_GE(param_no, 0); + CHECK_LT(param_no, param_instructions_.size()); + CHECK_EQ(nullptr, param_instructions_[param_no]); + param_instructions_[param_no] = instruction.get(); + } + root_found |= instruction.get() == root_instruction_; + AddInstructionInternal(std::move(instruction)); + } + CHECK(root_found); +} + +HloInstruction* HloComputation::AddInstruction( + std::unique_ptr instruction) { + CHECK(instruction->opcode() != HloOpcode::kParameter) + << "Parameter instructions cannot be added to a computation after " + << "it has been built"; + return AddInstructionInternal(std::move(instruction)); +} + +HloInstruction* HloComputation::AddInstructionInternal( + std::unique_ptr instruction) { + // Generate a unique name for the instruction. + instruction->set_name( + instruction_name_uniquer_.GetUniqueName(instruction->name())); + instruction->set_parent(this); + HloInstruction* pinst = instruction.get(); + instruction_iterators_[pinst] = + instructions_.insert(instructions_.end(), std::move(instruction)); + return pinst; +} + +void HloComputation::RemoveInstructionAndUnusedOperands( + HloInstruction* instruction) { + CHECK_NE(root_instruction(), instruction); + + CHECK_EQ(0, instruction->user_count()); + CHECK_NE(instruction->opcode(), HloOpcode::kParameter) + << "Cannot remove parameter instructions"; + std::queue remove; + remove.push(instruction); + while (!remove.empty()) { + HloInstruction* item = remove.front(); + remove.pop(); + if (item->user_count() != 0 || item == root_instruction_ || + item->opcode() == HloOpcode::kParameter) { + continue; + } + for (int i = 0; i < item->operand_count(); ++i) { + remove.push(item->mutable_operand(i)); + } + + // If an instruction has the same operand more than once, we must not remove + // it again. 
+ RemoveInstruction(item); + } +} + +bool HloComputation::RemoveInstructionIfFound(HloInstruction* instruction) { + CHECK_NE(instruction->opcode(), HloOpcode::kParameter) + << "Cannot remove parameter instructions"; + CHECK_NE(root_instruction(), instruction) << "cannot remove root instruction"; + CHECK_EQ(0, instruction->user_count()) + << "instruction with users cannot be removed"; + + if (instruction_iterators_.count(instruction) == 0) { + return false; + } + auto inst_it = instruction_iterators_.at(instruction); + (*inst_it)->set_parent(nullptr); + instruction->DetachFromOperands(); + instructions_.erase(inst_it); + return true; +} + +void HloComputation::RemoveInstruction(HloInstruction* instruction) { + CHECK(RemoveInstructionIfFound(instruction)) + << instruction->ToString() << " is not a member of computation " + << name(); +} + +void HloComputation::ReplaceUsesOfInstruction( + HloInstruction* instruction_to_replace, HloInstruction* instruction) { + instruction_to_replace->ReplaceAllUsesWith(instruction); + if (instruction_to_replace == root_instruction()) { + set_root_instruction(instruction); + } +} + +void HloComputation::set_root_instruction( + HloInstruction* new_root_instruction) { + // The shape of the root (ignoring layout) is an invariant of the computation. + CHECK(ShapeUtil::Compatible(new_root_instruction->shape(), + root_instruction_->shape())) + << new_root_instruction->shape().ShortDebugString() + << " is incompatible with " + << root_instruction_->shape().ShortDebugString(); + bool root_found = false; + for (auto& instruction : instructions_) { + if (new_root_instruction == instruction.get()) { + root_found = true; + break; + } + } + DCHECK(root_found); + + root_instruction_ = new_root_instruction; +} + +namespace { + +// Helper class which computes the post order of an expression rooted at a +// particular instruction. +class InstructionPostOrderer : public DfsHloVisitorWithDefault { + public: + // added_instructions is the set of instructions which have already been + // accounted for in the post order in previous invocations of + // GetOrder. Without this mechanism, instructions which are predecessors of + // multiple root instructions of the computation can be added to the post + // order more than once. + static std::list GetOrder( + HloInstruction* root, + tensorflow::gtl::FlatSet* added_instructions) { + InstructionPostOrderer orderer(added_instructions); + TF_CHECK_OK(root->Accept(&orderer)); + return std::move(orderer.post_order_); + } + + private: + explicit InstructionPostOrderer( + tensorflow::gtl::FlatSet* added_instructions) + : added_instructions_(added_instructions) {} + ~InstructionPostOrderer() override {} + + Status DefaultAction(HloInstruction* hlo_instruction) override { + if (added_instructions_->count(hlo_instruction) == 0) { + post_order_.push_back(hlo_instruction); + added_instructions_->insert(hlo_instruction); + } + return Status::OK(); + } + + std::list post_order_; + tensorflow::gtl::FlatSet* added_instructions_; +}; + +// Helper which builds a post order of the HLO call graph. 
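+//
+// For example, for the diamond call graph exercised in
+// hlo_computation_test.cc, where an entry computation maps with map1 and
+// map2 and both of those map with negate, the post order produced here is
+// {negate, map1, map2, entry}: callees always precede their callers, and a
+// computation reachable through several callers (negate) appears only once.
+// MakeEmbeddedComputationsList() returns this order with the computation
+// itself popped from the back, e.g. (with 'entry' standing for the calling
+// computation):
+//
+//   auto embedded = entry->MakeEmbeddedComputationsList();
+//   // embedded holds {negate, map1, map2}; negate comes first.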
+void ComputeComputationPostOrder( + HloComputation* computation, + tensorflow::gtl::FlatSet* visited, + std::list* post_order) { + if (visited->count(computation) > 0) { + return; + } + + for (auto& instruction : computation->instructions()) { + for (auto& called_computation : instruction->MakeCalledComputationsSet()) { + ComputeComputationPostOrder(called_computation, visited, post_order); + } + } + + visited->insert(computation); + post_order->push_back(computation); + return; +} + +} // namespace + +std::list HloComputation::MakeInstructionPostOrder() const { + std::list post_order; + std::list trace_instructions; + tensorflow::gtl::FlatSet added_instructions; + for (auto& instruction : instructions_) { + if (instruction->opcode() == HloOpcode::kTrace) { + // Trace instructions aren't handled by the DFS visitor. Add trace + // instructions to the post order at the end (necessarily they have no + // users). + trace_instructions.push_back(instruction.get()); + } else if (instruction->users().empty()) { + post_order.splice(post_order.end(), + InstructionPostOrderer::GetOrder(instruction.get(), + &added_instructions)); + } + } + post_order.splice(post_order.end(), trace_instructions); + CHECK_EQ(instructions_.size(), post_order.size()) + << "number of instructions does not match post order size"; + return post_order; +} + +std::list HloComputation::MakeEmbeddedComputationsList() + const { + tensorflow::gtl::FlatSet visited; + std::list post_order; + + // To avoid special handling of this computation, cast away const of + // 'this'. 'this' is immediately removed from the post order after + // construction. + ComputeComputationPostOrder(const_cast(this), &visited, + &post_order); + + // We don't want to include this computation in the post order. + CHECK_EQ(this, post_order.back()); + post_order.pop_back(); + + return post_order; +} + +string HloComputation::ToString() const { + std::ostringstream s; + s << name() << " " << ShapeUtil::HumanString(ComputeProgramShape()) + << " { \n"; + for (const HloInstruction* instruction : MakeInstructionPostOrder()) { + s << " " << instruction->ToString() << "\n"; + if (instruction->opcode() == HloOpcode::kFusion) { + for (const auto& fused_instruction : instruction->fused_instructions()) { + s << " " << fused_instruction->ToString() << "\n"; + } + } + } + s << "}"; + return s.str(); +} + +void HloComputation::FuseInstructionsInto( + tensorflow::gtl::ArraySlice instructions_to_fuse, + HloInstruction* fusion_instruction) { + CHECK_EQ(HloOpcode::kFusion, fusion_instruction->opcode()); + HloInstruction* root = instructions_to_fuse.front(); + root->ReplaceAllUsesWith(fusion_instruction); + if (root == root_instruction()) { + set_root_instruction(fusion_instruction); + } + RemoveInstruction(root); + for (size_t i = 1; i < instructions_to_fuse.size(); ++i) { + HloInstruction* instruction = instructions_to_fuse[i]; + fusion_instruction->FuseInstruction(instruction); + if (instruction->user_count() == 0) { + RemoveInstruction(instruction); + } + } +} + +HloInstruction* HloComputation::CreateFusionInstruction( + tensorflow::gtl::ArraySlice instructions_to_fuse, + HloInstruction::FusionKind fusion_kind) { + HloInstruction* root = instructions_to_fuse.front(); + HloInstruction* fusion_instruction = AddInstruction( + HloInstruction::CreateFusion(root->shape(), fusion_kind, root)); + FuseInstructionsInto(instructions_to_fuse, fusion_instruction); + return fusion_instruction; +} + +HloInstruction* HloComputation::CreateFusionInstructionForBackwardConvolution( + 
tensorflow::gtl::ArraySlice instructions_to_fuse, + HloInstruction::FusionKind fusion_kind, const Window& window, + const ConvolutionDimensionNumbers& conv_dnums) { + CHECK(HloInstruction::FusionKind::kConvBackwardFilter == fusion_kind || + HloInstruction::FusionKind::kConvBackwardInput == fusion_kind); + HloInstruction* root = instructions_to_fuse.front(); + HloInstruction* fusion_instruction = + AddInstruction(HloInstruction::CreateFusionForBackwardConvolution( + root->shape(), fusion_kind, window, conv_dnums, root)); + FuseInstructionsInto(instructions_to_fuse, fusion_instruction); + return fusion_instruction; +} + +StatusOr HloComputation::DeepCopyTuple( + HloInstruction* instruction) { + TF_RET_CHECK(ShapeUtil::IsTuple(instruction->shape())); + std::vector element_copies; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape()); + ++i) { + HloInstruction* gte = AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetSubshape(instruction->shape(), {i}), instruction, i)); + // Recurse to copy tuple elements. For array elements, insert a kCopy + // because GetTupleElement forwards a pointer to the tuple element buffer. + HloInstruction* element_copy; + if (ShapeUtil::IsTuple(gte->shape())) { + TF_ASSIGN_OR_RETURN(element_copy, DeepCopyTuple(gte)); + } else { + element_copy = AddInstruction( + HloInstruction::CreateUnary(gte->shape(), HloOpcode::kCopy, gte)); + } + element_copies.push_back(element_copy); + } + + // Gather element copies into a tuple with a new Tuple instruction. + return AddInstruction(HloInstruction::CreateTuple(element_copies)); +} + +StatusOr HloComputation::DeepCopyInstruction( + HloInstruction* instruction) { + if (instruction->parent() != this) { + return FailedPrecondition( + "Can't deep copy instruction %s: instruction is not in computation %s", + instruction->name().c_str(), name().c_str()); + } + + // For tuple instructions, perform a deep copy. For array instructions, copy + // with a kCopy instruction. + if (ShapeUtil::IsTuple(instruction->shape())) { + return DeepCopyTuple(instruction); + } else if (ShapeUtil::IsArray(instruction->shape())) { + return AddInstruction(HloInstruction::CreateUnary( + instruction->shape(), HloOpcode::kCopy, instruction)); + } else { + return FailedPrecondition( + "Can only copy array and tuple shaped instructions"); + } +} + +Status HloComputation::AddControlDependency(HloInstruction* predecessor, + HloInstruction* successor) { + TF_RET_CHECK(instruction_iterators_.count(predecessor) > 0); + TF_RET_CHECK(instruction_iterators_.count(successor) > 0); + successor->AddControlPredecessor(predecessor); + return Status::OK(); +} + +ProgramShape HloComputation::ComputeProgramShape() const { + ProgramShape program_shape; + + for (auto* param_instruction : param_instructions_) { + *program_shape.add_parameters() = param_instruction->shape(); + *program_shape.add_parameter_names() = param_instruction->parameter_name(); + } + *program_shape.mutable_result() = root_instruction_->shape(); + + LayoutUtil::ClearLayout(&program_shape); + return program_shape; +} + +bool HloComputation::operator==(const HloComputation& other) const { + std::set> visited; + std::function eq = + [&visited, &eq](const HloInstruction* a, const HloInstruction* b) { + // If are visited but not identical, the recursion should have + // been aborted. So, if are visited at this point, they must be + // identical. 
+ if (visited.count(std::make_pair(a, b)) > 0) return true; + visited.emplace(a, b); + return a->Identical( + *b, eq, [](const HloComputation* a, const HloComputation* b) { + return *a == *b; + }); + }; + return eq(root_instruction(), other.root_instruction()); +} + +void HloComputation::ReplaceWithNewInstruction( + HloInstruction* old_instruction, + std::unique_ptr new_instruction) { + ReplaceInstruction(old_instruction, + AddInstruction(std::move(new_instruction))); +} + +void HloComputation::ReplaceInstruction(HloInstruction* old_instruction, + HloInstruction* new_instruction) { + CHECK(ShapeUtil::Compatible(old_instruction->shape(), + new_instruction->shape())); + VLOG(10) << "transformed " << old_instruction->ToString() << " to " + << new_instruction->ToString(); + ReplaceUsesOfInstruction(old_instruction, new_instruction); + RemoveInstructionAndUnusedOperands(old_instruction); +} + +HloComputation::ReachabilityMap::ReachabilityMap( + const std::list& all_instructions) { + const int n = all_instructions.size(); + int next_id = 0; + for (const auto* hlo : all_instructions) { + ids_[hlo] = next_id; + next_id++; + } + DCHECK_EQ(n, ids_.size()); // instructions should be unique + matrix_.Reset(n * n); +} + +void HloComputation::ReachabilityMap::SetReachable(const HloInstruction* a, + const HloInstruction* b) { + const int id_a = FindOrDie(ids_, a); + const int id_b = FindOrDie(ids_, b); + matrix_.set(id_a * ids_.size() + id_b); +} + +bool HloComputation::ReachabilityMap::IsReachable( + const HloInstruction* a, const HloInstruction* b) const { + const int id_a = FindOrDie(ids_, a); + const int id_b = FindOrDie(ids_, b); + return matrix_.get(id_a * ids_.size() + id_b); +} + +bool HloComputation::ReachabilityMap::IsConnected( + const HloInstruction* a, const HloInstruction* b) const { + const int id_a = FindOrDie(ids_, a); + const int id_b = FindOrDie(ids_, b); + return matrix_.get(id_a * ids_.size() + id_b) || + matrix_.get(id_b * ids_.size() + id_a); +} + +void HloComputation::ReachabilityMap::SetReachableAndTransitiveClosure( + const HloInstruction* a, const HloInstruction* b) { + const int id_a = FindOrDie(ids_, a); + const int id_b = FindOrDie(ids_, b); + const int n = ids_.size(); + matrix_.set(id_a * n + id_b); + + // Copy transitive set for b into entries for a + for (int i = 0; i < n; i++) { + if (matrix_.get(id_b * n + i)) { + matrix_.set(id_a * n + i); + } + } +} + +std::unique_ptr +HloComputation::ComputeTransitiveOperands() const { + const auto all = MakeInstructionPostOrder(); + auto result = MakeUnique(all); + + // Fill in the dependency bit matrix + for (const auto* hlo : all) { + for (const HloInstruction* operand : hlo->operands()) { + result->SetReachableAndTransitiveClosure(hlo, operand); + } + } + return result; +} + +Status HloComputation::Accept(DfsHloVisitor* visitor) const { + // Visit all dead roots. + for (auto& instruction : instructions()) { + if (instruction->user_count() == 0 && + instruction.get() != root_instruction()) { + // Call FinishVisit only at the end. + TF_RETURN_IF_ERROR( + instruction->Accept(visitor, /*call_finish_visit=*/false)); + } + } + // Visit root instruction last. 
+ return root_instruction()->Accept(visitor, /*call_finish_visit=*/true); +} + +Status HloComputation::Accept( + const FunctionVisitor::VisitorFunction& visitor_func) const { + FunctionVisitor visitor(visitor_func); + return this->Accept(&visitor); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h new file mode 100644 index 0000000000..67df5797dc --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -0,0 +1,300 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/name_uniquer.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/bitmap.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +class HloModule; + +// Describes a computation at the HLO level. +// +// An HloComputation contains a directed acyclic graph of HLO instructions. The +// computation has a single root instruction which produces the output of the +// computation. +class HloComputation { + public: + // Builder class for HloComputation. + class Builder { + public: + explicit Builder(const string& name) + : name_(name), last_added_instruction_(nullptr) {} + + // Build and return an HloComputation. The parameter root_instruction + // specifies the already-added instruction to use as the root. If + // root_instruction is nullptr then use the last added instruction as the + // root. + std::unique_ptr Build( + HloInstruction* root_instruction = nullptr); + + HloInstruction* AddInstruction( + std::unique_ptr instruction) { + instructions_.push_back(std::move(instruction)); + last_added_instruction_ = instructions_.back().get(); + return last_added_instruction_; + } + + private: + const string name_; + HloInstruction* last_added_instruction_; + std::vector> instructions_; + }; + + // Add an instruction to the computation. The computation takes ownership of + // the instruction. + HloInstruction* AddInstruction(std::unique_ptr instruction); + + // Remove an instruction from the computation. The instruction must have no + // users. 
Instruction is deallocated with this call. + void RemoveInstruction(HloInstruction* instruction); + + // Remove an instruction from the computation and also transitively any + // operand that has no users post removing an instruction. The instruction + // must have no users. Instruction is deallocated with this call. + void RemoveInstructionAndUnusedOperands(HloInstruction* instruction); + + // Replace all uses of "instruction_to_replace" with "instruction". Also, if + // instruction_to_replace is the root of this computation then the root is set + // to "instruction". Does not remove "instruction_to_replace". + void ReplaceUsesOfInstruction(HloInstruction* instruction_to_replace, + HloInstruction* instruction); + + // Set the root of the computation to the given instruction. The instruction + // must have already been added to the computation and have the same shape as + // the result of the computation. + void set_root_instruction(HloInstruction* instruction); + + // Return the root instruction of the computation. The root instruction is the + // instruction which produces the output of the computation. + HloInstruction* root_instruction() const { return root_instruction_; } + + // Returns the number of parameters for this computation. + int64 num_parameters() const { return param_instructions_.size(); } + + // Returns the parameter instruction for the given parameter number. + HloInstruction* parameter_instruction(int64 param_no) const { + CHECK_GE(param_no, 0); + CHECK_LT(param_no, param_instructions_.size()); + return param_instructions_[param_no]; + } + + const std::vector& parameter_instructions() const { + return param_instructions_; + } + + const string& name() const { return name_; } + + // Return a string representation of the computation. + string ToString() const; + + const std::list>& instructions() const { + return instructions_; + } + + // Add a control dependency between the two instructions in this computation + // so that the 'predecessor' is visited before the 'successor' during the DFS + // traversal of the computation. Returns an error status if either of the + // given instructions does not belong to the current computation. + // + // This is used to enforce an additional ordering requirement that is not + // captured by normal data dependencies, such as ordering among Send or Recv + // operations to avoid deadlock. + Status AddControlDependency(HloInstruction* predecessor, + HloInstruction* successor); + + // Compute and return a post-order of the instructions in the computation. In + // this order, definitions of values always appear before their uses. + std::list MakeInstructionPostOrder() const; + + // Computes and returns the mapping from HLO to its transitive operands. + class ReachabilityMap; + std::unique_ptr ComputeTransitiveOperands() const; + + int64 instruction_count() const { return instructions_.size(); } + + // Creates and returns a list of the embedded computations called by this + // computation. This includes all embedded computations called directly or + // transitively. The embedded computations are sorted such that if computation + // A calls computation B (eg, via a map instruction) then A will appear after + // B in the list. + std::list MakeEmbeddedComputationsList() const; + + // Creates a fusion instruction containing the given instructions. + // `fusion_kind` indicates the type of the fusion, e.g., loop fusion or fusion + // into a library call. Instructions must be in reverse topological order + // (root of the fused expression first). 
Replaces all uses of the original + // root instruction with the fusion instruction. The original instructions are + // removed if they have no uses after fusion (this is necessarily true for at + // least the root). + HloInstruction* CreateFusionInstruction( + tensorflow::gtl::ArraySlice instructions_to_fuse, + HloInstruction::FusionKind fusion_kind); + + // Creates a fusion instruction that represents a backward convolution. This + // is similar to CreateFusionInstruction but takes window and conv_dnums which + // indicate the window and convolution dimension numbers of the backward + // convolution. + HloInstruction* CreateFusionInstructionForBackwardConvolution( + tensorflow::gtl::ArraySlice instructions_to_fuse, + HloInstruction::FusionKind fusion_kind, const Window& window, + const ConvolutionDimensionNumbers& conv_dnums); + + // Create a deep copy of the given instruction and return the instruction + // producing the copied result. All instructions performing the copy are added + // to the computation. For array-shaped values, this method trivially returns + // a kCopy instruction. For tuple-shaped instructions, the copy is performed + // with a series of kGetTupleElement and kTuple instructions. + StatusOr DeepCopyInstruction(HloInstruction* instruction); + + // Computes and returns the ProgramShape of this computation (shape of + // parameters and result without layout). + ProgramShape ComputeProgramShape() const; + + // Return whether `*this` and `other` are functionally equivalent. + bool operator==(const HloComputation& other) const; + + // Replaces old instruction with newly created instruction. Removes old + // instruction from computation. Updates uses and root instruction. + void ReplaceWithNewInstruction( + HloInstruction* old_instruction, + std::unique_ptr new_instruction); + + // Replace old instruction with new instruction. Updates uses and root + // instruction. Removes old instruction from computation. Precondition: + // old_instruction and new_instruction must have the compatible shapes. + void ReplaceInstruction(HloInstruction* old_instruction, + HloInstruction* new_instruction); + + // Set/get the module containing this computation. + void set_parent(HloModule* module) { parent_ = module; } + const HloModule* parent() const { return parent_; } + + // Visit every node in the computation in DFS post-order with the given + // visitor. This is similar to calling HloInstruction::Accept on the root of + // the computation except this method also visits instructions not reachable + // via the root. The root instruction of the computation is visited last, and + // the visitor's FinishVisit method is called once upon completion (with the + // root instruction as the argument). + Status Accept(DfsHloVisitor* visitor) const; + + // Same as Accept() above, but the visitor is given as a function. + Status Accept(const FunctionVisitor::VisitorFunction& visitor_func) const; + + private: + explicit HloComputation( + const string& name, int parameter_count, + std::vector>* instructions, + HloInstruction* root_instruction); + + // Internal helper for adding instructions. + HloInstruction* AddInstructionInternal( + std::unique_ptr instruction); + + // Remove an instruction from the computation if found. The instruction must + // have no users. Instruction is deallocated with this call. + // Return whether instruction was found and removed. + bool RemoveInstructionIfFound(HloInstruction* instruction); + + // Fuses HLOs in instructions_to_fuse into fusion_instruction. 
+ // + // Pre-condition: fusion_instruction's opcode is kFusion. + void FuseInstructionsInto( + tensorflow::gtl::ArraySlice instructions_to_fuse, + HloInstruction* fusion_instruction); + + // Internal helper for copying a tuple value. Creates and returns a deep copy + // of the given instruction. The given instruction must be tuple-shaped. + StatusOr DeepCopyTuple(HloInstruction* instruction); + + const string name_; + HloInstruction* root_instruction_; + + // Module containing this computation. + HloModule* parent_ = nullptr; + + // Store instructions in std::list as they can be added and removed + // arbitrarily and we want a stable iteration order. Keep a map from + // instruction pointer to location in the list for fast lookup. + using InstructionList = std::list>; + InstructionList instructions_; + std::unordered_map + instruction_iterators_; + + std::vector param_instructions_; + + // Unique name generator for instruction identifiers. Instruction names should + // be unique per computation and this is enforced when instructions are added + // to the computation. + NameUniquer instruction_name_uniquer_; + + TF_DISALLOW_COPY_AND_ASSIGN(HloComputation); +}; + +class HloComputation::ReachabilityMap { + public: + // Sets up an empty reachable matrix for the full set of + // instructions specified in "all_instructions" + explicit ReachabilityMap(const std::list& all_instructions); + // Sets entry so that IsReachable(a, b) will return true + void SetReachable(const HloInstruction* a, const HloInstruction* b); + + // Sets IsReachable(a_inst, b_inst) as well as IsReachable(a_inst, trans) + // for all "trans" s.t. "IsReachable(b_inst, trans)" is true + void SetReachableAndTransitiveClosure(const HloInstruction* a_inst, + const HloInstruction* b_inst); + + // Returns true if "b" is reachable from "a" + bool IsReachable(const HloInstruction* a, const HloInstruction* b) const; + + // Returns true if "b" is reachable from "a" or "a" is reachable from "b" + bool IsConnected(const HloInstruction* a, const HloInstruction* b) const; + + private: + friend class HloComputation; + + // dense id assignment from HloInstruction* to number + tensorflow::gtl::FlatMap ids_; + // matrix_(a,b) is true iff b is reachable from a + tensorflow::core::Bitmap matrix_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_ diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc new file mode 100644 index 0000000000..1e0d09b72c --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -0,0 +1,311 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_computation.h" + +#include + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" + +namespace xla { + +namespace { + +class HloComputationTest : public HloTestBase { + protected: + HloComputationTest() {} + + // Create a computation which takes a scalar and returns its negation. + std::unique_ptr CreateNegateComputation() { + auto builder = HloComputation::Builder("Negate"); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32_, "param0")); + builder.AddInstruction( + HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, param)); + return builder.Build(); + } + + // Creates a computation which calls map with the given computation. + std::unique_ptr CreateMapComputation( + HloComputation* map_computation) { + auto builder = HloComputation::Builder("Map"); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32_, "param0")); + builder.AddInstruction( + HloInstruction::CreateMap(r0f32_, {param}, map_computation)); + return builder.Build(); + } + + Shape r0f32_ = ShapeUtil::MakeShape(F32, {}); +}; + +TEST_F(HloComputationTest, GetEmbeddedComputationsEmpty) { + auto negate_computation = CreateNegateComputation(); + EXPECT_TRUE(negate_computation->MakeEmbeddedComputationsList().empty()); +} + +TEST_F(HloComputationTest, GetEmbeddedComputationsOneComputation) { + // Create computation which calls one other computation. + auto negate_computation = CreateNegateComputation(); + auto map_computation = CreateMapComputation(negate_computation.get()); + EXPECT_TRUE(negate_computation->MakeEmbeddedComputationsList().empty()); + EXPECT_EQ(map_computation->MakeEmbeddedComputationsList().front(), + negate_computation.get()); +} + +TEST_F(HloComputationTest, GetEmbeddedComputationsDiamond) { + // Create computations with a diamond-shaped callgraph. + auto negate_computation = CreateNegateComputation(); + auto map1_computation = CreateMapComputation(negate_computation.get()); + auto map2_computation = CreateMapComputation(negate_computation.get()); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32_, "param0")); + auto map1 = builder.AddInstruction( + HloInstruction::CreateMap(r0f32_, {param}, map1_computation.get())); + auto map2 = builder.AddInstruction( + HloInstruction::CreateMap(r0f32_, {param}, map2_computation.get())); + builder.AddInstruction( + HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, map1, map2)); + auto computation = builder.Build(); + + auto embedded_computations = computation->MakeEmbeddedComputationsList(); + EXPECT_EQ(3, embedded_computations.size()); + // GetEmbeddedComputations returns a post order of the embedded computations, + // so the negate computation must come first. 
+ EXPECT_EQ(negate_computation.get(), *embedded_computations.begin()); + EXPECT_MATCH(testing::ListToVec(embedded_computations), + testing::UnorderedMatcher( + negate_computation.get(), map1_computation.get(), + map2_computation.get())); +} + +TEST_F(HloComputationTest, PostOrderSingleton) { + // Test GetInstructionPostOrder for a computation with one instruction. + auto builder = HloComputation::Builder(TestName()); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + auto computation = builder.Build(); + + EXPECT_EQ(computation->MakeInstructionPostOrder().front(), constant); +} + +TEST_F(HloComputationTest, PostOrderSimple) { + // Test GetInstructionPostOrder for a computation with a chain of + // instructions. + auto builder = HloComputation::Builder(TestName()); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + auto negate1 = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); + auto negate2 = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, negate1)); + auto computation = builder.Build(); + + EXPECT_MATCH( + testing::ListToVec( + computation->MakeInstructionPostOrder()), + testing::OrderedMatcher(constant, negate1, negate2)); +} + +TEST_F(HloComputationTest, PostOrderTrace) { + // Test GetInstructionPostOrder for a computation with a trace instruction. + auto builder = HloComputation::Builder(TestName()); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + auto negate1 = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); + auto trace = + builder.AddInstruction(HloInstruction::CreateTrace("foobar", negate1)); + auto negate2 = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, negate1)); + auto computation = builder.Build(); + + // Trace instructions should be at the end of the sort. + EXPECT_MATCH(testing::ListToVec( + computation->MakeInstructionPostOrder()), + testing::OrderedMatcher(constant, negate1, + negate2, trace)); +} + +TEST_F(HloComputationTest, PostOrderDisconnectedInstructions) { + // Test GetInstructionPostOrder for a computation with multiple instructions + // which are not connected. + auto builder = HloComputation::Builder(TestName()); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + auto constant3 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + auto constant4 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + auto computation = builder.Build(); + + EXPECT_MATCH(testing::ListToVec( + computation->MakeInstructionPostOrder()), + testing::UnorderedMatcher( + constant1, constant2, constant3, constant4)); +} + +TEST_F(HloComputationTest, PostOrderWithMultipleRoots) { + // Test GetInstructionPostOrder for a computation with multiple instructions + // which are not connected. 
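+  // (The adds do share the constants, but each add is a root with no users,
+  // so all six instructions must still appear in the post order.)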
+ auto builder = HloComputation::Builder(TestName()); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + auto constant3 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( + r0f32_, HloOpcode::kAdd, constant1, constant2)); + auto add2 = builder.AddInstruction(HloInstruction::CreateBinary( + r0f32_, HloOpcode::kAdd, constant2, constant3)); + auto add3 = builder.AddInstruction(HloInstruction::CreateBinary( + r0f32_, HloOpcode::kAdd, constant1, constant3)); + auto computation = builder.Build(); + + auto post_order = computation->MakeInstructionPostOrder(); + EXPECT_EQ(6, post_order.size()); + EXPECT_MATCH(testing::ListToVec(post_order), + testing::UnorderedMatcher( + constant1, constant2, constant3, add1, add2, add3)); +} + +TEST_F(HloComputationTest, VisitWithMultipleRoots) { + // Test that Accept visits all instructions in the computation even if the + // computation has multiple roots (dead code). + auto builder = HloComputation::Builder(TestName()); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + auto constant3 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + // Add three disconnected add expressions. + builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, + constant1, constant2)); + builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, + constant2, constant3)); + builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, + constant1, constant3)); + auto computation = builder.Build(); + + // Visitor which keeps track of which instructions have been visited. + class TestVisitor : public DfsHloVisitorWithDefault { + public: + explicit TestVisitor(HloComputation* computation) + : computation_(computation) {} + + Status DefaultAction(HloInstruction* hlo_instruction) override { + EXPECT_EQ(0, visited_set_.count(hlo_instruction)); + visited_set_.insert(hlo_instruction); + last_visited_ = hlo_instruction; + return Status::OK(); + } + + Status FinishVisit(HloInstruction* root) override { + EXPECT_EQ(computation_->root_instruction(), root); + ++finish_visit_calls_; + return Status::OK(); + } + + HloComputation* computation_; + std::set visited_set_; + int64 finish_visit_calls_ = 0; + HloInstruction* last_visited_ = nullptr; + }; + + TestVisitor visitor(computation.get()); + EXPECT_IS_OK(computation->Accept(&visitor)); + + EXPECT_EQ(6, visitor.visited_set_.size()); + EXPECT_EQ(1, visitor.finish_visit_calls_); + EXPECT_EQ(computation->root_instruction(), visitor.last_visited_); +} + +TEST_F(HloComputationTest, DeepCopyArray) { + // Test that DeepCopyInstruction properly copies an array. + auto builder = HloComputation::Builder(TestName()); + auto constant = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1.0, 2.0, 3.0}))); + auto computation = builder.Build(); + + auto copy = computation->DeepCopyInstruction(constant).ValueOrDie(); + + EXPECT_EQ(HloOpcode::kCopy, copy->opcode()); + EXPECT_EQ(constant, copy->operand(0)); +} + +TEST_F(HloComputationTest, DeepCopyTuple) { + // Test that DeepCopyInstruction properly copies a tuple. 
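+  // The copy should have the form Tuple(Copy(GetTupleElement(tuple, 0)),
+  // Copy(GetTupleElement(tuple, 1))), which the expectations below verify
+  // element by element.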
+ auto builder = HloComputation::Builder(TestName()); + auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1.0, 2.0, 3.0}))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + auto tuple = builder.AddInstruction( + HloInstruction::CreateTuple({constant1, constant2})); + + auto computation = builder.Build(); + + auto tuple_copy = computation->DeepCopyInstruction(tuple).ValueOrDie(); + + EXPECT_EQ(HloOpcode::kTuple, tuple_copy->opcode()); + EXPECT_EQ(HloOpcode::kCopy, tuple_copy->operand(0)->opcode()); + const HloInstruction* gte0 = tuple_copy->operand(0)->operand(0); + EXPECT_EQ(HloOpcode::kGetTupleElement, gte0->opcode()); + EXPECT_EQ(0, gte0->tuple_index()); + EXPECT_EQ(tuple, gte0->operand(0)); + + EXPECT_EQ(HloOpcode::kCopy, tuple_copy->operand(1)->opcode()); + const HloInstruction* gte1 = tuple_copy->operand(1)->operand(0); + EXPECT_EQ(HloOpcode::kGetTupleElement, gte1->opcode()); + EXPECT_EQ(1, gte1->tuple_index()); + EXPECT_EQ(tuple, gte1->operand(0)); +} + +TEST_F(HloComputationTest, CycleDetection) { + // Test whether the visitor can detect cycles in the graph. + auto builder = HloComputation::Builder(TestName()); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, negate, negate)); + auto computation = builder.Build(); + + // Add a control dependency to create a cycle. + ASSERT_IS_OK(computation->AddControlDependency(add, negate)); + + const auto visitor = [](HloInstruction* instruction) { return Status::OK(); }; + auto visit_status = computation->Accept(visitor); + ASSERT_FALSE(visit_status.ok()); + ASSERT_MATCH(visit_status.error_message(), + testing::ContainsRegex("cycle is detecte")); +} + +} // namespace + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc new file mode 100644 index 0000000000..1b2a955f39 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -0,0 +1,350 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" + +#include + +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace xla { + +Status HloCostAnalysis::HandleElementwiseOp(HloInstruction* hlo_instruction) { + const auto& shape = hlo_instruction->shape(); + // For element-wise operations, the number of computations is the same as the + // number of elements in the output shape. 
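+  // For example, an elementwise kAdd producing an F32[10, 20] result adds 200
+  // to the flop count, while a kExp of the same shape adds 200 to the
+  // transcendental count instead (illustrative shapes, not from a test).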
+ auto computation_count = ShapeUtil::ElementsIn(shape); + auto opcode = hlo_instruction->opcode(); + // We treat the two opcodes (kExp, kPower) as transcendental operations. + if (opcode == HloOpcode::kExp || opcode == HloOpcode::kPower) { + transcendental_count_ += computation_count; + } else { + // Note: transcendental operations are considered a separate category from + // FLOPs. + hlo_to_flop_count_[hlo_instruction] = computation_count; + flop_count_ += computation_count; + } + return Status::OK(); +} + +Status HloCostAnalysis::HandleElementwiseUnary(HloInstruction* hlo, + HloOpcode opcode, + HloInstruction* operand) { + return HandleElementwiseOp(hlo); +} + +Status HloCostAnalysis::HandleElementwiseBinary(HloInstruction* hlo, + HloOpcode opcode, + HloInstruction* lhs, + HloInstruction* rhs) { + return HandleElementwiseOp(hlo); +} + +Status HloCostAnalysis::HandleCompare(HloInstruction* compare, HloOpcode opcode, + HloInstruction* lhs, + HloInstruction* rhs) { + return HandleElementwiseOp(compare); +} + +Status HloCostAnalysis::HandleClamp(HloInstruction* clamp, + HloInstruction* min_instruction, + HloInstruction* arg_instruction, + HloInstruction* max_instruction) { + return HandleElementwiseOp(clamp); +} + +Status HloCostAnalysis::HandleParameter(HloInstruction* parameter) { + return Status::OK(); +} + +Status HloCostAnalysis::HandleConstant(HloInstruction* constant, + const Literal& literal) { + return Status::OK(); +} + +Status HloCostAnalysis::HandleGetTupleElement(HloInstruction* get_tuple_element, + HloInstruction* operand) { + return Status::OK(); +} + +Status HloCostAnalysis::HandleSelect(HloInstruction* select, + HloInstruction* pred, + HloInstruction* on_true, + HloInstruction* on_false) { + return Status::OK(); +} + +Status HloCostAnalysis::HandleReverse(HloInstruction* reverse, + HloInstruction* operand_instruction) { + return Status::OK(); +} + +Status HloCostAnalysis::HandleSlice(HloInstruction* slice, + HloInstruction* operand_instruction) { + return Status::OK(); +} + +Status HloCostAnalysis::HandleDynamicSlice( + HloInstruction* slice, + tensorflow::gtl::ArraySlice operands) { + return Status::OK(); +} + +Status HloCostAnalysis::HandleDynamicUpdateSlice( + HloInstruction* dynamic_update, HloInstruction* operand, + HloInstruction* update, HloInstruction* start_indices) { + return Status::OK(); +} + +Status HloCostAnalysis::HandleTuple( + HloInstruction* tuple, + tensorflow::gtl::ArraySlice operands) { + return Status::OK(); +} + +Status HloCostAnalysis::HandleConcatenate( + HloInstruction* concatenate, + tensorflow::gtl::ArraySlice operands) { + return Status::OK(); +} + +Status HloCostAnalysis::HandleConvert(HloInstruction* convert, + HloInstruction* operand) { + flop_count_ += ShapeUtil::ElementsIn(operand->shape()); + return Status::OK(); +} + +Status HloCostAnalysis::HandleCopy(HloInstruction* copy, + HloInstruction* operand) { + return Status::OK(); +} + +Status HloCostAnalysis::HandleDot(HloInstruction* dot, + HloInstruction* lhs_instruction, + HloInstruction* rhs_instruction) { + // We count an FMA operation as 2 floating point operations. + // Multiplying the sizes of lhs, rhs, and result produces the square of the + // number of FMAs during the computation. 
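+  // Worked example: for a [10, 5] x [5, 30] Dot, ElementsIn returns 50, 150,
+  // and 300, so sqrt(50 * 150 * 300) = 1500 FMAs and 3000 flops, matching the
+  // 2 * 10 * 30 * 5 expected by the MatrixMultiply test.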
+ auto fma_count = std::sqrt( + static_cast(ShapeUtil::ElementsIn(lhs_instruction->shape())) * + ShapeUtil::ElementsIn(rhs_instruction->shape()) * + ShapeUtil::ElementsIn(dot->shape())); + flop_count_ += 2 * fma_count; + hlo_to_flop_count_[dot] = 2 * fma_count; + return Status::OK(); +} + +Status HloCostAnalysis::HandleInfeed(HloInstruction* infeed) { + return Status::OK(); +} + +Status HloCostAnalysis::HandleMap( + HloInstruction* map, tensorflow::gtl::ArraySlice operands, + HloComputation* function, + tensorflow::gtl::ArraySlice /*static_operands*/) { + // Compute the cost of the user function. + HloInstruction* function_instruction = function->root_instruction(); + HloCostAnalysis visitor; + TF_RETURN_IF_ERROR(function_instruction->Accept(&visitor)); + + // Compute the cost of all elements for this Map operation. + auto element_count = ShapeUtil::ElementsIn(map->shape()); + flop_count_ += element_count * visitor.flop_count(); + transcendental_count_ += element_count * visitor.transcendental_count(); + return Status::OK(); +} + +Status HloCostAnalysis::HandleReduce( + HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value, + tensorflow::gtl::ArraySlice dimensions, HloComputation* function) { + // Compute the cost of the user function. + HloInstruction* function_instruction = function->root_instruction(); + HloCostAnalysis visitor; + TF_RETURN_IF_ERROR(function_instruction->Accept(&visitor)); + + // Compute the cost of all elements for this Reduce operation. + auto reduction_count = ShapeUtil::ElementsIn(arg->shape()) - + ShapeUtil::ElementsIn(reduce->shape()); + flop_count_ += reduction_count * visitor.flop_count(); + transcendental_count_ += reduction_count * visitor.transcendental_count(); + return Status::OK(); +} + +Status HloCostAnalysis::HandleReduceWindow(HloInstruction* reduce_window, + HloInstruction* operand, + const Window& window, + HloComputation* function) { + // Compute the cost of the user function. + HloInstruction* function_instruction = function->root_instruction(); + HloCostAnalysis visitor; + TF_RETURN_IF_ERROR(function_instruction->Accept(&visitor)); + + // Compute the cost of all elements for this ReduceWindow operation. For each + // output element, (window_size - 1) number of user computations are applied. + auto output_size = ShapeUtil::ElementsIn(reduce_window->shape()); + int64 window_size = 1; + for (const auto& dimension : window.dimensions()) { + window_size *= dimension.size(); + } + flop_count_ += output_size * (window_size - 1) * visitor.flop_count(); + transcendental_count_ += + output_size * (window_size - 1) * visitor.transcendental_count(); + return Status::OK(); +} + +Status HloCostAnalysis::HandleSelectAndScatter(HloInstruction* instruction) { + // Compute the cost of the select and scatter function. + HloInstruction* select = instruction->select()->root_instruction(); + HloCostAnalysis select_visitor; + TF_RETURN_IF_ERROR(select->Accept(&select_visitor)); + HloInstruction* scatter = instruction->scatter()->root_instruction(); + HloCostAnalysis scatter_visitor; + TF_RETURN_IF_ERROR(scatter->Accept(&scatter_visitor)); + + // Compute the cost of all elements for this operation. For each scatter + // source element, (window_size - 1) number of select computations and 1 + // scatter computation are applied. 
+ const auto source = instruction->operand(1); + const auto source_element_count = ShapeUtil::ElementsIn(source->shape()); + int64 window_size = 1; + for (const auto& dimension : instruction->window().dimensions()) { + window_size *= dimension.size(); + } + flop_count_ += + source_element_count * ((window_size - 1) * select_visitor.flop_count() + + scatter_visitor.flop_count()); + transcendental_count_ += + source_element_count * + ((window_size - 1) * select_visitor.transcendental_count() + + scatter_visitor.transcendental_count()); + return Status::OK(); +} + +Status HloCostAnalysis::HandleBitcast(HloInstruction* bitcast) { + return Status::OK(); +} + +Status HloCostAnalysis::HandleBroadcast(HloInstruction* broadcast) { + return Status::OK(); +} + +Status HloCostAnalysis::HandlePad(HloInstruction* pad) { return Status::OK(); } + +Status HloCostAnalysis::HandleSend(HloInstruction* send) { + return Status::OK(); +} + +Status HloCostAnalysis::HandleRecv(HloInstruction* recv) { + return Status::OK(); +} + +Status HloCostAnalysis::HandleReshape(HloInstruction* reshape) { + return Status::OK(); +} + +Status HloCostAnalysis::HandleTranspose(HloInstruction* transpose) { + return Status::OK(); +} + +Status HloCostAnalysis::HandleConvolution(HloInstruction* convolution, + HloInstruction* lhs_instruction, + HloInstruction* rhs_instruction, + const Window& window) { + const auto& dnums = convolution->convolution_dimension_numbers(); + const int64 output_features = + convolution->shape().dimensions(dnums.feature_dimension()); + + // For each output element, we do one fma per element in the + // kernel at some given output feature index. + const int64 fmas_per_output_element = + ShapeUtil::ElementsIn(rhs_instruction->shape()) / output_features; + const int64 output_elements = ShapeUtil::ElementsIn(convolution->shape()); + const double hlo_flop_count = static_cast(output_elements) * + fmas_per_output_element * kFmaFlops; + flop_count_ += hlo_flop_count; + hlo_to_flop_count_[convolution] = hlo_flop_count; + return Status::OK(); +} + +Status HloCostAnalysis::HandleCrossReplicaSum(HloInstruction* crs) { + // We assume 2 replicas, so that each output element is the sum of two input + // elements. + // + // TODO(b/33004697): Compute correct cost here, taking the actual number of + // replicas into account. + const double hlo_flop_count = ShapeUtil::ElementsIn(crs->shape()); + flop_count_ += hlo_flop_count; + hlo_to_flop_count_[crs] = hlo_flop_count; + return Status::OK(); +} + +Status HloCostAnalysis::HandleRng(HloInstruction* random, + RandomDistribution distribution) { + // TODO(b/26346211): Implement better estimates for the RNG cost, since the + // cost changes with the implementation and the distribution. For now, assume + // the cost of each RNG is same as a transcendental operation. + transcendental_count_ += ShapeUtil::ElementsIn(random->shape()); + return Status::OK(); +} + +Status HloCostAnalysis::HandleFusion(HloInstruction* fusion) { + // Fusion instruction itself does not contribute to computation. 
+ return fusion->fused_expression_root()->Accept(this); +} + +Status HloCostAnalysis::HandleCall( + HloInstruction* call, tensorflow::gtl::ArraySlice operands, + HloComputation* computation) { + return Unimplemented("call"); +} + +Status HloCostAnalysis::HandleCustomCall( + HloInstruction* custom_call, + tensorflow::gtl::ArraySlice operands, + tensorflow::StringPiece custom_call_target) { + return Unimplemented("custom-call"); +} + +Status HloCostAnalysis::HandleSort(HloInstruction* sort, + HloInstruction* operand_instruction) { + // The cost of sort is implementation dependent, so cannot determine at HLO + // level. Maybe just assume the comparison based N*log(N) sorting? + // TODO(b/26346211): Implement the cost model for sort. + return Unimplemented("HandleSort"); +} + +Status HloCostAnalysis::HandleWhile(HloInstruction* xla_while, + HloInstruction* init, + HloComputation* condition, + HloComputation* body) { + // Since the number of iterations of the while node is not statically + // determined, we cannot analyze the computation cost of a while node. + // TODO(b/26346211): Add cost analysis for while node. + return Unimplemented("HandleWhile"); +} + +Status HloCostAnalysis::FinishVisit(HloInstruction* root) { + return Status::OK(); +} + +double HloCostAnalysis::hlo_to_flop_count(const HloInstruction& hlo) const { + auto it = hlo_to_flop_count_.find(&hlo); + return it == hlo_to_flop_count_.end() ? 0.0 : it->second; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h new file mode 100644 index 0000000000..8bed07c07e --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -0,0 +1,147 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COST_ANALYSIS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COST_ANALYSIS_H_ + +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// HloCostAnalysis traverses an HLO graph and calculates the amount of +// computations required for the graph. Each HLO instruction handler provides +// the computation cost of the instruction, and the values are accumulated +// during the traversal for the entire graph. We treat normal floating point +// operations separately from transcendental operations. 
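+//
+// Typical usage, mirroring the accompanying tests: run the visitor over the
+// root instruction of a computation and then read back the accumulated counts,
+// e.g.
+//
+//   HloCostAnalysis analysis;
+//   TF_RETURN_IF_ERROR(computation->root_instruction()->Accept(&analysis));
+//   double flops = analysis.flop_count();
+//   double transcendentals = analysis.transcendental_count();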
+class HloCostAnalysis : public DfsHloVisitor {
+ public:
+  HloCostAnalysis() = default;
+  ~HloCostAnalysis() override = default;
+
+  Status HandleElementwiseUnary(HloInstruction* hlo, HloOpcode opcode,
+                                HloInstruction* operand) override;
+  Status HandleElementwiseBinary(HloInstruction* hlo, HloOpcode opcode,
+                                 HloInstruction* lhs,
+                                 HloInstruction* rhs) override;
+  Status HandleConstant(HloInstruction* constant,
+                        const Literal& literal) override;
+  Status HandleGetTupleElement(HloInstruction* get_tuple_element,
+                               HloInstruction* operand) override;
+  Status HandleSelect(HloInstruction* select, HloInstruction* pred,
+                      HloInstruction* on_true,
+                      HloInstruction* on_false) override;
+  Status HandleCompare(HloInstruction* compare, HloOpcode opcode,
+                       HloInstruction* lhs, HloInstruction* rhs) override;
+  Status HandleClamp(HloInstruction* clamp, HloInstruction* min,
+                     HloInstruction* arg, HloInstruction* max) override;
+  Status HandleConcatenate(
+      HloInstruction* concatenate,
+      tensorflow::gtl::ArraySlice<HloInstruction*> operands) override;
+  Status HandleSend(HloInstruction* send) override;
+  Status HandleRecv(HloInstruction* recv) override;
+  Status HandleConvert(HloInstruction* convert,
+                       HloInstruction* operand) override;
+  Status HandleCopy(HloInstruction* copy, HloInstruction* operand) override;
+  Status HandleDot(HloInstruction* dot, HloInstruction* lhs,
+                   HloInstruction* rhs) override;
+  Status HandleConvolution(HloInstruction* convolution, HloInstruction* lhs,
+                           HloInstruction* rhs, const Window& window) override;
+  Status HandleCrossReplicaSum(HloInstruction* crs) override;
+  Status HandleInfeed(HloInstruction* infeed) override;
+  Status HandleRng(HloInstruction* random,
+                   RandomDistribution distribution) override;
+  Status HandleReverse(HloInstruction* reverse,
+                       HloInstruction* operand) override;
+  Status HandleSort(HloInstruction* sort, HloInstruction* operand) override;
+  Status HandleParameter(HloInstruction* parameter) override;
+  Status HandleReduce(HloInstruction* reduce, HloInstruction* arg,
+                      HloInstruction* init_value,
+                      tensorflow::gtl::ArraySlice<int64> dimensions,
+                      HloComputation* function_handle) override;
+  Status HandleFusion(HloInstruction* fusion) override;
+  Status HandleCall(HloInstruction* call,
+                    tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+                    HloComputation* computation) override;
+  Status HandleCustomCall(
+      HloInstruction* custom_call,
+      tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+      tensorflow::StringPiece custom_call_target) override;
+  Status HandleSlice(HloInstruction* slice, HloInstruction* operand) override;
+  Status HandleDynamicSlice(
+      HloInstruction* slice,
+      tensorflow::gtl::ArraySlice<HloInstruction*> operands) override;
+  Status HandleDynamicUpdateSlice(HloInstruction* dynamic_update_slice,
+                                  HloInstruction* operand,
+                                  HloInstruction* update,
+                                  HloInstruction* start_indices) override;
+  Status HandleTuple(
+      HloInstruction* tuple,
+      tensorflow::gtl::ArraySlice<HloInstruction*> operands) override;
+  Status HandleMap(
+      HloInstruction* map,
+      tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+      HloComputation* function,
+      tensorflow::gtl::ArraySlice<HloInstruction*> static_operands) override;
+  Status HandleReduceWindow(HloInstruction* reduce_window,
+                            HloInstruction* operand, const Window& window,
+                            HloComputation* function) override;
+  Status HandleSelectAndScatter(HloInstruction* instruction) override;
+  Status HandleBitcast(HloInstruction* bitcast) override;
+  Status HandleBroadcast(HloInstruction* broadcast) override;
+  Status HandlePad(HloInstruction* pad) override;
+  Status HandleReshape(HloInstruction* reshape) override;
+  Status
HandleTranspose(HloInstruction* transpose) override; + Status HandleWhile(HloInstruction* xla_while, HloInstruction* init, + HloComputation* condition, HloComputation* body) override; + Status FinishVisit(HloInstruction* root) override; + + // Returns the amount of computations in the graph. + double flop_count() { return flop_count_; } + double transcendental_count() { return transcendental_count_; } + + // Resolves the provided HLO instruction to a flop count, or 0 if the HLO was + // not found to have a flop count in the analysis. + double hlo_to_flop_count(const HloInstruction& hlo) const; + + private: + // An FMA counts as two floating point operations in these analyses. + static constexpr int64 kFmaFlops = 2; + + // Utility function to handle all element-wise operations. + Status HandleElementwiseOp(HloInstruction* hlo_instruction); + + // Mapping from HLO instructions to the flop count we computed for them in the + // course of the graph analysis. + std::map hlo_to_flop_count_; + + // The number of floating point operations in the graph. + double flop_count_ = 0; + + // The number of transcendental operations in the graph. + double transcendental_count_ = 0; + + TF_DISALLOW_COPY_AND_ASSIGN(HloCostAnalysis); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COST_ANALYSIS_H_ diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc new file mode 100644 index 0000000000..776d40ac4d --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc @@ -0,0 +1,337 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" + +#include +#include + +#include "tensorflow/compiler/xla/client/client.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/padding.h" +#include "tensorflow/compiler/xla/service/computation_tracker.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/local_service.h" +#include "tensorflow/compiler/xla/service/service.h" +#include "tensorflow/compiler/xla/service/user_computation.h" +#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/platform/logging.h" + +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test_helpers.h" + +namespace xla { +namespace { + +// This test suite tests the HLO cost analysis by first building a computation +// using the client computation builder and running the HloCostAnalysis that +// returns the number of floating point and transcendental operations in the +// graph. We test both individual HLO operations as well as a mixed graph. +class HloCostAnalysisTest : public ::testing::Test { + protected: + HloCostAnalysisTest() + : client_(ClientLibrary::LocalClientOrDie()), + // Accessing service instance is required for the unit tests to enable + // whitebox acccesses to the user computation built from the client, + // as shown in the BuildHloGraph functions below. + service_(static_cast(ClientLibrary::GetXlaService( + static_cast(client_)->platform()))), + computation_tracker_(service_->computation_tracker()) { + // Create a computation for a unary user function: x => exp(x + 0.5) + { + ComputationBuilder builder(client_, "add_and_exp"); + auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); + auto half = builder.ConstantR0(0.5); + builder.Exp(builder.Add(x, half)); + auto computation_status = builder.Build(); + TF_CHECK_OK(computation_status.status()); + add_and_exp_ = computation_status.ConsumeValueOrDie(); + } + + // Create a computation for a binary user function: (x, y) => x + y + { + ComputationBuilder builder(client_, "add"); + auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); + builder.Add(x, y); + auto computation_status = builder.Build(); + TF_CHECK_OK(computation_status.status()); + add_ = computation_status.ConsumeValueOrDie(); + } + + // Create a computation for a sigmoid function: x => 1 / (1 + exp(-x)) + { + ComputationBuilder builder(client_, "sigmoid"); + auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); + auto one = builder.ConstantR0(1.0); + builder.Div(one, builder.Add(one, builder.Exp(builder.Neg(x)))); + auto computation_status = builder.Build(); + TF_CHECK_OK(computation_status.status()); + sigmoid_ = computation_status.ConsumeValueOrDie(); + } + + // Create a computation for a binary max function: (x, y) => max (x, y) + { + ComputationBuilder builder(client_, "max"); + auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); + builder.Max(x, y); + auto computation_status = builder.Build(); + 
TF_CHECK_OK(computation_status.status()); + max_ = computation_status.ConsumeValueOrDie(); + } + + // Create a computation for a binary GT function: (x, y) => x > y + { + ComputationBuilder builder(client_, "gt"); + auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); + builder.Gt(x, y); + auto computation_status = builder.Build(); + TF_CHECK_OK(computation_status.status()); + gt_ = computation_status.ConsumeValueOrDie(); + } + } + + // Build HLO graph from the given builder and return the HLO module. + std::unique_ptr BuildHloGraph(ComputationBuilder* builder) { + auto computation_status = builder->Build(); + TF_CHECK_OK(computation_status.status()); + auto computation = computation_status.ConsumeValueOrDie(); + auto user_computation_status = + computation_tracker_.Resolve(computation.handle()); + TF_CHECK_OK(user_computation_status.status()); + auto user_computation = user_computation_status.ConsumeValueOrDie(); + VersionedComputationHandle versioned_handle = + user_computation->GetVersionedHandle(); + return std::move( + computation_tracker_.BuildHloModule(versioned_handle).ValueOrDie()); + } + + Client* client_; + Service* service_; + const ComputationTracker& computation_tracker_; + + // User computations used for higher order operations (e.g., Map, Reduce). + Computation add_; + Computation add_and_exp_; + Computation sigmoid_; + Computation max_; + Computation gt_; +}; + +TEST_F(HloCostAnalysisTest, MatrixMultiply) { + ComputationBuilder builder(client_, "matrix_multiply"); + auto lhs = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 5}), "lhs"); + auto rhs = builder.Parameter(1, ShapeUtil::MakeShape(F32, {5, 30}), "rhs"); + auto result = builder.Dot(lhs, rhs); + + // Run HLO cost analysis. + auto hlo_module = BuildHloGraph(&builder); + HloCostAnalysis analysis; + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + // Check the number of computations returned from the analysis (1500 FMAs). + EXPECT_EQ(analysis.flop_count(), 2 * 10 * 30 * 5); +} + +TEST_F(HloCostAnalysisTest, Map) { + ComputationBuilder builder(client_, "map"); + auto input = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10}), "in"); + auto result = builder.Map({input}, add_and_exp_); + + // Run HLO cost analysis. + auto hlo_module = BuildHloGraph(&builder); + HloCostAnalysis analysis; + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + // add contributes to 10 flops and exp contributes to 10 transcendental ops. + EXPECT_EQ(analysis.flop_count(), 10); + EXPECT_EQ(analysis.transcendental_count(), 10); +} + +TEST_F(HloCostAnalysisTest, Convolution) { + ComputationBuilder builder(client_, "convolution"); + auto input = builder.Parameter( + 0, ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/10, + /*x_dim=*/20}), + "input"); + auto kernel = builder.Parameter( + 1, ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/3, + /*x_dim=*/3}), + "kernel"); + auto result = builder.Conv(input, kernel, {1, 1}, Padding::kValid); + + // Run HLO cost analysis. + auto hlo_module = BuildHloGraph(&builder); + HloCostAnalysis analysis; + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + // Output shape is [1x1x8x18] and each output element requires (3x3) + // FMAs and one FMA is 2 flops. 
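+  // That is 8 * 18 = 144 output elements, each costing 3 * 3 = 9 FMAs at
+  // 2 flops apiece, for a total of 2592 flops.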
+ EXPECT_EQ(analysis.flop_count(), 8 * 18 * 2 * 3 * 3); +} + +TEST_F(HloCostAnalysisTest, Reduce) { + ComputationBuilder builder(client_, "reduce"); + auto input = + builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input"); + auto result = + builder.Reduce(input, builder.ConstantR0(0.0f), add_, {1}); + + // Run HLO cost analysis. + auto hlo_module = BuildHloGraph(&builder); + HloCostAnalysis analysis; + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + // Subtracting the output size from the input size gives the number of + // reduction operations performed. + EXPECT_EQ(analysis.flop_count(), 10 * 20 - 10); +} + +TEST_F(HloCostAnalysisTest, ReduceWindow) { + ComputationBuilder builder(client_, "reduce_window"); + auto input = + builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input"); + auto result = builder.ReduceWindow(input, builder.ConstantR0(0), add_, + {4, 5}, {4, 5}, Padding::kValid); + + // Run HLO cost analysis. + auto hlo_module = BuildHloGraph(&builder); + HloCostAnalysis analysis; + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + // Each of [2x4] output elements are generated from reducing [4x5] elements. + EXPECT_EQ(analysis.flop_count(), 2 * 4 * (4 * 5 - 1)); +} + +TEST_F(HloCostAnalysisTest, SelectAndScatter) { + ComputationBuilder builder(client_, "select_and_scatter"); + auto operand = + builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input"); + auto source = + builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 4}), "source"); + auto result = + builder.SelectAndScatter(operand, gt_, {4, 5}, {4, 5}, Padding::kValid, + source, builder.ConstantR0(0), add_); + + // Run HLO cost analysis. + auto hlo_module = BuildHloGraph(&builder); + HloCostAnalysis analysis; + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + // Each of [2x4] source elements computes its destination from reducing [4x5] + // elements followed by the scatter computation. + EXPECT_EQ(analysis.flop_count(), 2 * 4 * (4 * 5 - 1 + 1)); +} + +TEST_F(HloCostAnalysisTest, Broadcast) { + ComputationBuilder b(client_, "broadcast"); + b.Broadcast(b.ConstantR0(42), {10, 7}); + auto hlo_module = BuildHloGraph(&b); + HloCostAnalysis analysis; + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + EXPECT_EQ(analysis.flop_count(), 0); +} + +// Calculates the computation cost of a graph with more than one HLO node. +TEST_F(HloCostAnalysisTest, FullyConnectedForward) { + ComputationBuilder builder(client_, "fully_connected_forward"); + auto input = + builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 5}), "input"); + auto weight = + builder.Parameter(1, ShapeUtil::MakeShape(F32, {5, 20}), "weight"); + auto bias = builder.Parameter(2, ShapeUtil::MakeShape(F32, {20}), "bias"); + // sigmoid(input * weight + bias) + auto result = builder.Map( + {builder.Add(builder.Dot(input, weight), bias, {1})}, sigmoid_); + + // Run HLO cost analysis. + auto hlo_module = BuildHloGraph(&builder); + HloCostAnalysis analysis; + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + // 1000 FMAs from matrix multiplication, 200 flops from bias addition, + // 600 flops from sigmoid, and 200 transcendental ops from sigmoid. 
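+  // The sigmoid body 1 / (1 + exp(-x)) costs one negate, one add, and one
+  // divide (3 flops) plus one exp (1 transcendental) per element, and the
+  // mapped output has 10 * 20 = 200 elements.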
+ EXPECT_EQ(analysis.flop_count(), 2 * 1000 + 200 + 3 * 200); + EXPECT_EQ(analysis.transcendental_count(), 200); +} + +TEST_F(HloCostAnalysisTest, MatmulAndConvolutionCanBeTheSameComputation) { + HloCostAnalysis conv_analysis; + { + ComputationBuilder builder(client_, "conv_looking_matmul"); + auto lhs = builder.Parameter(0, ShapeUtil::MakeShape(F32, {64, 64, 1, 1}), + "input"); + auto rhs = builder.Parameter(1, ShapeUtil::MakeShape(F32, {64, 64, 1, 1}), + "weights"); + builder.Conv(lhs, rhs, {1, 1}, Padding::kSame); + auto hlo_module = BuildHloGraph(&builder); + ASSERT_IS_OK(hlo_module->entry_computation()->root_instruction()->Accept( + &conv_analysis)); + } + + HloCostAnalysis matmul_analysis; + { + ComputationBuilder builder(client_, "matmul"); + auto lhs = + builder.Parameter(0, ShapeUtil::MakeShape(F32, {64, 64}), "input"); + auto rhs = + builder.Parameter(1, ShapeUtil::MakeShape(F32, {64, 64}), "weights"); + builder.Dot(lhs, rhs); + auto hlo_module = BuildHloGraph(&builder); + ASSERT_IS_OK(hlo_module->entry_computation()->root_instruction()->Accept( + &matmul_analysis)); + } + + EXPECT_EQ(conv_analysis.flop_count(), matmul_analysis.flop_count()); +} + +// Note that we still expect that any given operation won't overflow 2^64 FLOPs, +// just that the sum total may. +TEST_F(HloCostAnalysisTest, TotalOverflowsInt64) { + HloCostAnalysis matmul_analysis; + { + ComputationBuilder builder(client_, "matmul"); + auto lhs = builder.Parameter(0, ShapeUtil::MakeShape(F32, {1, 1LL << 62}), + "input"); + auto rhs = builder.Parameter(1, ShapeUtil::MakeShape(F32, {1LL << 62, 1}), + "weights"); + auto a = builder.Dot(lhs, rhs); + auto b = builder.Dot(a, lhs); + builder.Dot(b, rhs); + auto hlo_module = BuildHloGraph(&builder); + ASSERT_IS_OK(hlo_module->entry_computation()->root_instruction()->Accept( + &matmul_analysis)); + } + + LOG(INFO) << matmul_analysis.flop_count(); + EXPECT_GT(matmul_analysis.flop_count(), std::numeric_limits::max()); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc new file mode 100644 index 0000000000..7c28ff9da1 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_cse.cc @@ -0,0 +1,134 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_cse.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +namespace { + +// Find and combine identical constants. 
Constants are identical if they have +// the same type and value. +bool CombineConstants(HloComputation* computation, bool is_layout_sensitive) { + bool changed = false; + + // Map from ShortDebugString of the layoutless shape of the constant to the + // set of constant instructions with that shape. Layoutless shape is used to + // bin possible common constants together to reduce number of constant + // comparisons. If we end up having too many constant comparisons, a more + // precise binning might have to be used. + std::multimap constants; + + auto inst_it = computation->instructions().begin(); + while (inst_it != computation->instructions().end()) { + HloInstruction* instruction = inst_it->get(); + + // Advance list iterator before loop body because iterator may be + // invalidated due to deletion. + ++inst_it; + + if (instruction->opcode() == HloOpcode::kConstant) { + Shape shape = instruction->shape(); + if (!is_layout_sensitive) { + LayoutUtil::ClearLayout(&shape); + } + string shape_string = shape.ShortDebugString(); + + // Compare against all constants with the same shape + auto range = constants.equal_range(shape_string); + HloInstruction* match = nullptr; + for (auto it = range.first; it != range.second; ++it) { + if (LiteralUtil::Equal(instruction->literal(), it->second->literal())) { + match = it->second; + break; + } + } + if (match == nullptr) { + constants.emplace(shape_string, instruction); + } else { + // Match found, replace this instruction with the one in the multimap. + computation->ReplaceUsesOfInstruction(instruction, match); + computation->RemoveInstruction(instruction); + changed = true; + } + } + } + + return changed; +} + +} // namespace + +StatusOr HloCSE::Run(HloModule* module) { + bool changed = false; + for (auto& computation : module->computations()) { + changed |= CombineConstants(computation.get(), is_layout_sensitive_); + + std::list post_order = + computation->MakeInstructionPostOrder(); + std::set removed_instructions; + for (auto instruction : post_order) { + // If the instruction has already been removed by CSE skip over it. + if (removed_instructions.count(instruction) > 0 || + instruction->operand_count() == 0) { + continue; + } + + // An instruction is considered to be equivalent to another only if they + // share the exact same set of operands. So to find equivalent + // instructions, we just search among instructions which share operand(0) + // of this instruction. + const HloInstruction* operand = instruction->operand(0); + + std::vector equivalent_instructions; + for (HloInstruction* user : operand->users()) { + if (user != instruction && user->Identical(*instruction) && + (!is_layout_sensitive_ || + ShapeUtil::Equal(user->shape(), instruction->shape()))) { + equivalent_instructions.push_back(user); + } + } + + // Replace all equivalent instructions with this instruction. + for (HloInstruction* equivalent_instruction : equivalent_instructions) { + computation->ReplaceUsesOfInstruction(equivalent_instruction, + instruction); + computation->RemoveInstruction(equivalent_instruction); + removed_instructions.insert(equivalent_instruction); + changed = true; + } + } + } + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_cse.h b/tensorflow/compiler/xla/service/hlo_cse.h new file mode 100644 index 0000000000..5b8b82462a --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_cse.h @@ -0,0 +1,46 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CSE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CSE_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass.h" + +namespace xla { + +// A pass which performs common-subexpression elimination. Identical constants +// and identical instructions with the same operands are commoned. The pass +// iterates over the instructions in topological order which enables the pass to +// find arbitrarily large common expressions. +class HloCSE : public HloPass { + public: + // If is_layout_sensitive is true, then the simplifier preserves layout during + // transformation. Otherwise, layout is ignored. + explicit HloCSE(bool is_layout_sensitive) + : HloPass("cse"), is_layout_sensitive_(is_layout_sensitive) {} + ~HloCSE() override {} + + // Run CSE on the given module. Returns whether the module was changed (common + // subexpressions were found and eliminated). + StatusOr Run(HloModule* module) override; + + private: + bool is_layout_sensitive_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CSE_H_ diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc new file mode 100644 index 0000000000..ec8161f55f --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc @@ -0,0 +1,428 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_cse.h" + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class HloCseTest : public HloTestBase { + protected: + HloCseTest() {} +}; + +TEST_F(HloCseTest, CombineTwoConstants) { + // Test that two identical constants are commoned. + auto builder = HloComputation::Builder(TestName()); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + builder.AddInstruction(HloInstruction::CreateBinary( + constant1->shape(), HloOpcode::kAdd, constant1, constant2)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_EQ(3, computation->instruction_count()); + + HloCSE cse(/*is_layout_sensitive=*/false); + EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + + EXPECT_EQ(2, computation->instruction_count()); + HloInstruction* constant = computation->instructions().begin()->get(); + EXPECT_EQ(42.0f, LiteralUtil::Get(constant->literal(), {})); + + auto result = ExecuteAndTransfer(std::move(module), {}); + auto expected = LiteralUtil::CreateR0(84.0); + LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(1e-4)); +} + +TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) { + // Test that two identical constants with different layouts are commoned if + // the pass is not layout sensitive. 
+ auto builder = HloComputation::Builder(TestName()); + auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant( + test_utils::CreateR2LiteralWithLayout({{1.0, 2.0}, {3.0, 4.0}}, + /*minor_to_major=*/{0, 1}))); + auto constant2 = builder.AddInstruction(HloInstruction::CreateConstant( + test_utils::CreateR2LiteralWithLayout({{1.0, 2.0}, {3.0, 4.0}}, + /*minor_to_major=*/{1, 0}))); + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + constant1->shape(), HloOpcode::kAdd, constant1, constant2)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_EQ(3, computation->instruction_count()); + EXPECT_NE(add->operand(0), add->operand(1)); + + HloCSE cse(/*is_layout_sensitive=*/false); + EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + + EXPECT_EQ(2, computation->instruction_count()); + EXPECT_EQ(add->operand(0), add->operand(1)); + + auto result = ExecuteAndTransfer(std::move(module), {}); + auto expected = LiteralUtil::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); + LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(1e-4)); +} + +TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) { + // Test that two identical constants with different layouts are *not* commoned + // if the pass is layout sensitive. + auto builder = HloComputation::Builder(TestName()); + auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant( + test_utils::CreateR2LiteralWithLayout({{1.0, 2.0}, {3.0, 4.0}}, + /*minor_to_major=*/{0, 1}))); + auto constant2 = builder.AddInstruction(HloInstruction::CreateConstant( + test_utils::CreateR2LiteralWithLayout({{1.0, 2.0}, {3.0, 4.0}}, + /*minor_to_major=*/{1, 0}))); + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + constant1->shape(), HloOpcode::kAdd, constant1, constant2)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_EQ(3, computation->instruction_count()); + EXPECT_EQ(constant1, add->operand(0)); + EXPECT_EQ(constant2, add->operand(1)); + + HloCSE cse(/*is_layout_sensitive=*/true); + EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + + EXPECT_EQ(3, computation->instruction_count()); + EXPECT_EQ(constant1, add->operand(0)); + EXPECT_EQ(constant2, add->operand(1)); + + auto result = ExecuteAndTransfer(std::move(module), {}); + auto expected = LiteralUtil::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); + LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(1e-4)); +} + +TEST_F(HloCseTest, ConstantsSameValueDifferentType) { + // Test that constants with the same value but different type are *not* + // commoned. + auto builder = HloComputation::Builder(TestName()); + builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42))); + builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42))); + builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + // Duplicate the float constant to verify something happens. 
+ builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + + HloModule module(TestName()); + auto computation = module.AddEntryComputation(builder.Build()); + + EXPECT_EQ(7, computation->instruction_count()); + + HloCSE cse(/*is_layout_sensitive=*/false); + EXPECT_TRUE(cse.Run(&module).ValueOrDie()); + + EXPECT_EQ(6, computation->instruction_count()); +} + +TEST_F(HloCseTest, NonscalarConstants) { + // Test that identical nonscalar constants are merged. + auto builder = HloComputation::Builder(TestName()); + auto common_constant1 = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + auto common_constant2 = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + // Create a constant which has the same shape but a different value. + auto uncommon_constant = + builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{2.0, 4.0}, {6.0, 8.0}}))); + + // Tie the constants together with a tuple. This makes it easier to refer to + // the constant instructions via their use. + auto tuple = builder.AddInstruction(HloInstruction::CreateTuple( + {common_constant1, common_constant2, uncommon_constant})); + + HloModule module(TestName()); + auto computation = module.AddEntryComputation(builder.Build()); + + EXPECT_EQ(4, computation->instruction_count()); + + HloCSE cse(/*is_layout_sensitive=*/false); + EXPECT_TRUE(cse.Run(&module).ValueOrDie()); + + EXPECT_EQ(3, computation->instruction_count()); + + EXPECT_EQ(tuple->operand(0), tuple->operand(1)); + EXPECT_EQ(uncommon_constant, tuple->operand(2)); + EXPECT_TRUE(tuple->operand(0) == common_constant1 || + tuple->operand(0) == common_constant2); +} + +TEST_F(HloCseTest, IdenticalInstructions) { + // Test that three identical instructions are commoned. + auto builder = HloComputation::Builder(TestName()); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + auto exp1 = builder.AddInstruction(HloInstruction::CreateUnary( + constant->shape(), HloOpcode::kExp, constant)); + auto exp2 = builder.AddInstruction(HloInstruction::CreateUnary( + constant->shape(), HloOpcode::kExp, constant)); + auto exp3 = builder.AddInstruction(HloInstruction::CreateUnary( + constant->shape(), HloOpcode::kExp, constant)); + auto tuple = + builder.AddInstruction(HloInstruction::CreateTuple({exp1, exp2, exp3})); + + HloModule module(TestName()); + auto computation = module.AddEntryComputation(builder.Build()); + + EXPECT_EQ(5, computation->instruction_count()); + EXPECT_NE(tuple->operand(0), tuple->operand(1)); + EXPECT_NE(tuple->operand(1), tuple->operand(2)); + EXPECT_NE(tuple->operand(0), tuple->operand(2)); + + HloCSE cse(/*is_layout_sensitive=*/false); + EXPECT_TRUE(cse.Run(&module).ValueOrDie()); + + EXPECT_EQ(3, computation->instruction_count()); + EXPECT_EQ(tuple->operand(0), tuple->operand(1)); + EXPECT_EQ(tuple->operand(1), tuple->operand(2)); +} + +TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsSensitive) { + // Test that two identical instructions with different layouts are *not* + // commoned if the pass is layout sensitive. 
+ auto builder = HloComputation::Builder(TestName()); + auto constant = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + + auto exp1 = builder.AddInstruction(HloInstruction::CreateUnary( + constant->shape(), HloOpcode::kExp, constant)); + *exp1->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); + + auto exp2 = builder.AddInstruction(HloInstruction::CreateUnary( + constant->shape(), HloOpcode::kExp, constant)); + *exp2->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 0}); + + auto tuple = + builder.AddInstruction(HloInstruction::CreateTuple({exp1, exp2})); + + HloModule module(TestName()); + auto computation = module.AddEntryComputation(builder.Build()); + + EXPECT_EQ(4, computation->instruction_count()); + EXPECT_NE(tuple->operand(0), tuple->operand(1)); + + HloCSE cse(/*is_layout_sensitive=*/true); + EXPECT_FALSE(cse.Run(&module).ValueOrDie()); + + EXPECT_EQ(4, computation->instruction_count()); + EXPECT_NE(tuple->operand(0), tuple->operand(1)); +} + +TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsInsensitive) { + // Test that two identical instructions with different layouts are commoned if + // the pass is layout insensitive. + auto builder = HloComputation::Builder(TestName()); + auto constant = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + + auto exp1 = builder.AddInstruction(HloInstruction::CreateUnary( + constant->shape(), HloOpcode::kExp, constant)); + *exp1->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); + + auto exp2 = builder.AddInstruction(HloInstruction::CreateUnary( + constant->shape(), HloOpcode::kExp, constant)); + *exp2->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 0}); + + auto tuple = + builder.AddInstruction(HloInstruction::CreateTuple({exp1, exp2})); + + HloModule module(TestName()); + auto computation = module.AddEntryComputation(builder.Build()); + + EXPECT_EQ(4, computation->instruction_count()); + EXPECT_NE(tuple->operand(0), tuple->operand(1)); + + HloCSE cse(/*is_layout_sensitive=*/false); + EXPECT_TRUE(cse.Run(&module).ValueOrDie()); + + EXPECT_EQ(3, computation->instruction_count()); + EXPECT_EQ(tuple->operand(0), tuple->operand(1)); +} + +TEST_F(HloCseTest, IdenticalExpressions) { + // Test that two identical expressions are commoned. Build the following + // computation: + // + // constant = 42.0 + // negate1 = neg(constant) + // exp1 = exp(constant) + // add1 = add(negate1, exp1) + // negate2 = neg(constant) + // exp2 = exp(constant) + // add2 = add(negate2, exp2) + // tuple = tuple(add1, add2) + // + // The *1 instructions should be merged with the *2 instructions. 
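  // Expected result after CSE (an illustrative continuation of the sketch
  // above, matching the assertions at the end of this test):
  //
  //   constant = 42.0
  //   negate = neg(constant)
  //   exp = exp(constant)
  //   add = add(negate, exp)
  //   tuple = tuple(add, add)
  //
  // i.e. five instructions, with both tuple operands pointing at the same add.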
+ auto builder = HloComputation::Builder(TestName()); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + + auto negate1 = builder.AddInstruction(HloInstruction::CreateUnary( + constant->shape(), HloOpcode::kNegate, constant)); + auto exp1 = builder.AddInstruction(HloInstruction::CreateUnary( + constant->shape(), HloOpcode::kExp, constant)); + auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( + constant->shape(), HloOpcode::kAdd, negate1, exp1)); + + auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary( + constant->shape(), HloOpcode::kNegate, constant)); + auto exp2 = builder.AddInstruction(HloInstruction::CreateUnary( + constant->shape(), HloOpcode::kExp, constant)); + auto add2 = builder.AddInstruction(HloInstruction::CreateBinary( + constant->shape(), HloOpcode::kAdd, negate2, exp2)); + + auto tuple = + builder.AddInstruction(HloInstruction::CreateTuple({add1, add2})); + + HloModule module(TestName()); + auto computation = module.AddEntryComputation(builder.Build()); + + EXPECT_EQ(8, computation->instruction_count()); + EXPECT_NE(tuple->operand(0), tuple->operand(1)); + + HloCSE cse(/*is_layout_sensitive=*/false); + EXPECT_TRUE(cse.Run(&module).ValueOrDie()); + + EXPECT_EQ(5, computation->instruction_count()); + EXPECT_EQ(tuple->operand(0), tuple->operand(1)); + EXPECT_EQ(HloOpcode::kAdd, tuple->operand(0)->opcode()); +} + +TEST_F(HloCseTest, DoNotCombineRng) { + // Test that two RNG ops are not commoned. + auto builder = HloComputation::Builder(TestName()); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); + auto rng1 = builder.AddInstruction(HloInstruction::CreateRng( + ShapeUtil::MakeShape(F32, {}), RandomDistribution::RNG_UNIFORM, + {constant1, constant2})); + auto rng2 = builder.AddInstruction(HloInstruction::CreateRng( + ShapeUtil::MakeShape(F32, {}), RandomDistribution::RNG_UNIFORM, + {constant1, constant2})); + builder.AddInstruction(HloInstruction::CreateBinary( + constant1->shape(), HloOpcode::kAdd, rng1, rng2)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + + uint32 count_before = computation->instruction_count(); + + HloCSE cse(/*is_layout_sensitive=*/false); + EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + + uint32 count_after = computation->instruction_count(); + EXPECT_EQ(count_before, count_after); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kAdd); + EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kRng); + EXPECT_EQ(root->operand(1)->opcode(), HloOpcode::kRng); + EXPECT_NE(root->operand(0), root->operand(1)); +} + +// TODO(b/28245743): Handle impure functions correctly in CSE. +TEST_F(HloCseTest, DISABLED_DoNotCombineCallsToImpureFunctions) { + // Test that two calls to an impure function are not commoned. RNG + // is the source of the impurity. + + auto module = MakeUnique(TestName()); + + // rng_function is an impure function because it does RNG. 
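  // Shape of the module assembled below (an illustrative pseudo-HLO sketch,
  // not text emitted by the test):
  //
  //   rng_function(param) = add(rng(RNG_UNIFORM, 0.0, 1.0), param)
  //   entry:
  //     constant = f32[1] {5.0}
  //     rng1 = map(constant, rng_function)
  //     rng2 = map(constant, rng_function)
  //     root = add(rng1, rng2)
  //
  // A correct CSE must keep rng1 and rng2 distinct because each map call draws
  // its own random values.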
+ HloComputation* rng_function = nullptr; + { + Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + auto builder = HloComputation::Builder(TestName() + "_rng_fun"); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); + auto rng = builder.AddInstruction(HloInstruction::CreateRng( + scalar_shape, RandomDistribution::RNG_UNIFORM, {constant1, constant2})); + auto param = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "param")); + builder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape, HloOpcode::kAdd, rng, param)); + rng_function = module->AddEmbeddedComputation(builder.Build()); + } + + // Computation calls rng_function twice with the same parameter. + HloComputation* computation = nullptr; + { + auto builder = HloComputation::Builder(TestName()); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1({5.0f}))); + auto rng1 = builder.AddInstruction( + HloInstruction::CreateMap(constant->shape(), {constant}, rng_function)); + auto rng2 = builder.AddInstruction( + HloInstruction::CreateMap(constant->shape(), {constant}, rng_function)); + builder.AddInstruction(HloInstruction::CreateBinary( + constant->shape(), HloOpcode::kAdd, rng1, rng2)); + computation = module->AddEntryComputation(builder.Build()); + } + + EXPECT_EQ(4, computation->instruction_count()); + + HloCSE cse(/*is_layout_sensitive=*/false); + EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + + EXPECT_EQ(4, computation->instruction_count()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kAdd); + EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kMap); + EXPECT_EQ(root->operand(1)->opcode(), HloOpcode::kMap); + EXPECT_NE(root->operand(0), root->operand(1)); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_dce.cc b/tensorflow/compiler/xla/service/hlo_dce.cc new file mode 100644 index 0000000000..056bbc2473 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_dce.cc @@ -0,0 +1,69 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_dce.h"
+
+#include <memory>
+#include <set>
+#include <unordered_set>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/status.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+
+StatusOr<bool> HloDCE::Run(HloModule* module) {
+  bool changed = false;
+
+  for (auto& computation : module->computations()) {
+    std::unordered_set<HloInstruction*> live_instructions;
+    TF_RETURN_IF_ERROR(computation->root_instruction()->Accept(
+        [&live_instructions](HloInstruction* instruction) {
+          live_instructions.insert(instruction);
+          return Status::OK();
+        }));
+
+    // Remove any dead roots and their dead transitive operands. Collect them
+    // into a separate list first to avoid problems with iterating through the
+    // computation's instructions while simultaneously removing instructions.
+    std::vector<HloInstruction*> dead_roots;
+    for (auto& instruction : computation->instructions()) {
+      if (instruction->user_count() == 0 &&
+          live_instructions.count(instruction.get()) == 0 &&
+          instruction->opcode() != HloOpcode::kParameter) {
+        dead_roots.push_back(instruction.get());
+      }
+    }
+
+    for (HloInstruction* dead_root : dead_roots) {
+      computation->RemoveInstructionAndUnusedOperands(dead_root);
+      changed = true;
+    }
+  }
+
+  return changed;
+}
+
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_dce.h b/tensorflow/compiler/xla/service/hlo_dce.h
new file mode 100644
index 0000000000..53ba352890
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_dce.h
@@ -0,0 +1,43 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DCE_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DCE_H_
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_pass.h"
+#include "tensorflow/compiler/xla/statusor.h"
+
+namespace xla {
+
+// HLO pass which removes all dead instructions from each computation in the
+// module. An instruction is dead if it is not reachable from the root. This
+// pass does not remove dead parameter instructions, as parameter instructions
+// cannot be deleted, nor does the pass remove dead computations.
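//
// Example use (an illustrative sketch that mirrors the pattern in
// hlo_dce_test.cc; `module` is assumed to be a std::unique_ptr<HloModule>):
//
//   HloDCE dce;
//   bool changed = dce.Run(module.get()).ValueOrDie();
//   // `changed` is true iff at least one dead instruction was removed.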
+class HloDCE : public HloPass {
+ public:
+  HloDCE() : HloPass("dce") {}
+  ~HloDCE() override {}
+
+  // Run the pass on the given module. Returns whether the module was changed
+  // (instructions were removed).
+  StatusOr<bool> Run(HloModule* module) override;
+};
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DCE_H_
diff --git a/tensorflow/compiler/xla/service/hlo_dce_test.cc b/tensorflow/compiler/xla/service/hlo_dce_test.cc
new file mode 100644
index 0000000000..dcd9e00c56
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_dce_test.cc
@@ -0,0 +1,97 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_dce.h"
+
+#include <memory>
+
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class HloDceTest : public HloTestBase {
+ protected:
+  HloDceTest() {}
+};
+
+TEST_F(HloDceTest, NoDeadCode) {
+  // Verify that no dead code is removed from a computation with no dead code.
+  auto builder = HloComputation::Builder(TestName());
+  auto constant1 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
+  auto constant2 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0f)));
+  builder.AddInstruction(HloInstruction::CreateBinary(
+      constant1->shape(), HloOpcode::kAdd, constant1, constant2));
+
+  auto module = MakeUnique<HloModule>(TestName());
+  auto computation = module->AddEntryComputation(builder.Build());
+
+  EXPECT_EQ(3, computation->instruction_count());
+
+  HloDCE dce;
+  EXPECT_FALSE(dce.Run(module.get()).ValueOrDie());
+
+  EXPECT_EQ(3, computation->instruction_count());
+}
+
+TEST_F(HloDceTest, DeadParameters) {
+  // Verify that dead parameters are not removed, but uses of the dead
+  // parameters are.
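  // Liveness in the computation assembled below (an illustrative summary of
  // this test, not additional test code):
  //   live_param and negate(live_param)  -> live; the second negate is the root
  //   dead_param1, dead_param2           -> dead but kept, since parameters
  //                                         are never removed
  //   negate(dead_param1)                -> dead and removed by the pass
  // Hence the instruction count drops from 5 to 4 and dead_param1 loses its
  // single user.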
+ auto builder = HloComputation::Builder(TestName()); + auto live_param = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "live_param")); + auto dead_param1 = builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {}), "dead_param1")); + builder.AddInstruction(HloInstruction::CreateParameter( + 2, ShapeUtil::MakeShape(F32, {}), "dead_param2")); + + // This is a dead negate instruction. + builder.AddInstruction(HloInstruction::CreateUnary( + dead_param1->shape(), HloOpcode::kNegate, dead_param1)); + + // This negate is not dead because it is the root. + builder.AddInstruction(HloInstruction::CreateUnary( + live_param->shape(), HloOpcode::kNegate, live_param)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_EQ(5, computation->instruction_count()); + EXPECT_EQ(1, dead_param1->user_count()); + + HloDCE dce; + EXPECT_TRUE(dce.Run(module.get()).ValueOrDie()); + + EXPECT_EQ(4, computation->instruction_count()); + EXPECT_EQ(0, dead_param1->user_count()); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.cc b/tensorflow/compiler/xla/service/hlo_execution_profile.cc new file mode 100644 index 0000000000..edba55f6cd --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_execution_profile.cc @@ -0,0 +1,87 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_execution_profile.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/stringprintf.h" + +namespace xla { + +void HloExecutionProfile::AddProfileResult(const HloInstruction* hlo, + uint64 cycles_taken) { + hlo_to_cycles_taken_[hlo] = cycles_taken; +} + +uint64 HloExecutionProfile::GetProfileResult(const HloInstruction& hlo) const { + auto iter = hlo_to_cycles_taken_.find(&hlo); + if (iter == hlo_to_cycles_taken_.end()) { + return 0; + } + return iter->second; +} + +string HloExecutionProfile::ToString( + const DeviceDescription& device_description, + const HloCostAnalysis& cost_analysis) const { + using Item = std::pair; + std::vector items(hlo_to_cycles_taken_.begin(), + hlo_to_cycles_taken_.end()); + auto custom_less = [](const Item& lhs, const Item& rhs) { + return lhs.second > rhs.second; + }; + std::sort(items.begin(), items.end(), custom_less); + string result; + const int64 total_cycles = total_cycles_executed(); + double clock_rate_ghz = device_description.clock_rate_ghz(); + auto append_item = [&result, total_cycles, clock_rate_ghz]( + int64 cycles, int64 flops, const string& name) { + double nsecs = cycles / clock_rate_ghz; + tensorflow::strings::StrAppend( + &result, + tensorflow::strings::Printf( + "%15lld cycles (%6.2f%%) :: %12.1f usec @ f_nom :: %18s :: %s", + cycles, cycles / static_cast(total_cycles) * 100, + nsecs / 1e3, + flops <= 0 ? "" : HumanReadableNumFlops(flops, nsecs).c_str(), + name.c_str())); + }; + tensorflow::strings::StrAppend( + &result, + tensorflow::strings::Printf("HLO execution profile: (%s @ f_nom)\n\t", + tensorflow::strings::HumanReadableElapsedTime( + total_cycles / clock_rate_ghz / 1e9) + .c_str())); + append_item(total_cycles, -1, "[total]"); + for (const auto& item : items) { + tensorflow::strings::StrAppend(&result, "\n\t"); + auto flops = item.first == nullptr + ? -1 + : cost_analysis.hlo_to_flop_count(*item.first); + string display = item.first == nullptr ? "" : item.first->ToString(); + append_item(item.second, flops, display); + } + return result; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.h b/tensorflow/compiler/xla/service/hlo_execution_profile.h new file mode 100644 index 0000000000..6cc2079813 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_execution_profile.h @@ -0,0 +1,71 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EXECUTION_PROFILE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EXECUTION_PROFILE_H_ + +#include + +#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +class HloInstruction; + +// Describes how much time each HLO operation took. +// +// Each HloComputation takes a certain number of cycles. This class helps break +// down how much time each HLO took. +class HloExecutionProfile { + public: + using DeviceDescription = perftools::gputools::DeviceDescription; + + // Record how many cycles this HLO took to execute. + void AddProfileResult(const HloInstruction* hlo, uint64 cycles_taken); + + // Returns how many cycles this HLO took to execute. Profiling information + // may not be available for some instructions in which case zero is returned. + uint64 GetProfileResult(const HloInstruction& hlo) const; + + // Return the number of cycles this computation took to execute. + uint64 total_cycles_executed() const { return total_cycles_executed_; } + + // Record how many cycles the entire computation took to execute. + void set_total_cycles_executed(uint64 total_cycles_executed) { + total_cycles_executed_ = total_cycles_executed; + } + + // Returns a version of the execution profile suitable for performance + // debugging; e.g. emits cycle counts, execution time at the nominal device + // frequency, and the effective throughput given the provided cost_analysis + // for the operations. + string ToString(const DeviceDescription& device_description, + const HloCostAnalysis& cost_analysis) const; + + private: + // Contains a mapping from HLO to the number of cycles it took to execute it. + std::unordered_map hlo_to_cycles_taken_; + + // If non-empty, contains the total number of cycles this computation took to + // execute. + uint64 total_cycles_executed_ = 0; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EXECUTION_PROFILE_H_ diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc new file mode 100644 index 0000000000..4865a8fb45 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -0,0 +1,507 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" + +#include + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/protobuf.h" + +using ::tensorflow::Env; +using ::tensorflow::WriteStringToFile; +using ::tensorflow::io::JoinPath; +using ::tensorflow::strings::Appendf; +using ::tensorflow::strings::Printf; +using ::tensorflow::strings::StrAppend; +using ::tensorflow::strings::StrCat; +using ::tensorflow::str_util::Join; + +namespace xla { +namespace hlo_graph_dumper { +namespace { + +// Returns the dot graph identifier for the given instruction. +string InstructionId(const HloInstruction* instruction) { + return Printf("%lld", reinterpret_cast(instruction)); +} + +// Returns the dot graph identifier for the given computation. +string ComputationId(const HloComputation* computation) { + return Printf("%lld", reinterpret_cast(computation)); +} + +// Returns a compact string that represents the convolution dimension numbers. +string ConvolutionDimensionNumbersToString( + const ConvolutionDimensionNumbers& dim_numbers) { + return Printf("B@%lld,Z@%lld,KIZ@%lld,KOZ@%lld", + dim_numbers.batch_dimension(), dim_numbers.feature_dimension(), + dim_numbers.kernel_input_feature_dimension(), + dim_numbers.kernel_output_feature_dimension()); +} + +// Returns a compact string that represents the non-trivial fields in the window +// description. If there are no non-trivial fields, the empty string is +// returned. +string WindowToString(const Window& window) { + bool display_padding = false; + bool display_window_dilation = false; + bool display_base_dilation = false; + bool display_stride = false; + for (const WindowDimension& dimension : window.dimensions()) { + display_padding |= + dimension.padding_low() != 0 || dimension.padding_high() != 0; + display_window_dilation |= dimension.window_dilation() != 1; + display_base_dilation |= dimension.base_dilation() != 1; + display_stride |= dimension.stride() != 1; + } + std::vector pieces = {}; + if (display_padding) { + pieces.push_back("\\n"); + pieces.push_back("padding=["); + for (const WindowDimension& dimension : window.dimensions()) { + pieces.push_back(StrCat("(", dimension.padding_low(), ",", + dimension.padding_high(), ")")); + pieces.push_back(", "); + } + pieces.pop_back(); + pieces.push_back("]"); + } + // Make a convenient lambda that adds a simple int64 field in each + // WindowDimension. 
+ auto add_field = [&pieces, &window]( + const string& label, + tensorflow::protobuf_int64 (WindowDimension::*member)() const) { + pieces.push_back("\\n"); + pieces.push_back(label + "=["); + for (const WindowDimension& dimension : window.dimensions()) { + pieces.push_back(StrCat(((&dimension)->*member)())); + pieces.push_back(", "); + } + pieces.pop_back(); + pieces.push_back("]"); + }; + if (display_window_dilation) { + add_field("window_dilation", &WindowDimension::window_dilation); + } + if (display_base_dilation) { + add_field("base_dilation", &WindowDimension::base_dilation); + } + if (display_stride) { + add_field("stride", &WindowDimension::stride); + } + return Join(pieces, ""); +} + +// Returns the dot graph edges and nodes for the given instruction sequence. +// Edges which extend between computations are added to the vector +// intercomputation_edges. This is necessary because graphviz does not render +// the graph properly unless these inter-computation edges appear after all +// subgraph statements. +string InstructionSequenceGraph( + const std::list>& instructions, + bool show_addresses, bool show_layouts, + std::vector* intercomputation_edges, + const HloExecutionProfile* hlo_execution_profile) { + string graph_body; + + // Create a single "record" node for the parameters. This node is a + // partitioned rectangle with one partition per parameter node. The keeps + // all the parameter instructions together. + std::vector param_instructions; + for (auto& instruction : instructions) { + if (instruction->opcode() == HloOpcode::kParameter) { + int64 param_number = instruction->parameter_number(); + if (param_instructions.size() < param_number + 1) { + param_instructions.resize(param_number + 1, nullptr); + } + param_instructions[param_number] = instruction.get(); + } + } + string param_node_name; + if (!param_instructions.empty()) { + std::vector param_ports; + param_node_name = + StrCat("parameters_", InstructionId(param_instructions[0])); + for (auto& param : param_instructions) { + string label = StrCat(param->parameter_name(), "\\n", + ShapeUtil::HumanString(param->shape())); + if (show_addresses) { + Appendf(&label, "\\n[%p]", param); + } + if (show_layouts) { + StrAppend(&label, "\\nlayout=\\{", + Join(param->shape().layout().minor_to_major(), ","), "\\}"); + } + param_ports.push_back( + Printf("<%s> %s", InstructionId(param).c_str(), label.c_str())); + } + StrAppend(&graph_body, param_node_name, + " [shape=record,style=filled,fillcolor=\"lightblue1\",", + "label=\"{parameters | {", Join(param_ports, "|"), "}}\"];\n"); + } + + for (auto& instruction : instructions) { + string color = "peachpuff"; + string shape = "ellipse"; + string name = HloOpcodeString(instruction->opcode()); + if (HloOpcode::kFusion == instruction->opcode()) { + name += ": " + FusionKindString(instruction->fusion_kind()); + } + if (HloOpcode::kConvolution == instruction->opcode()) { + name += ":\\n" + ConvolutionDimensionNumbersToString( + instruction->convolution_dimension_numbers()) + + WindowToString(instruction->window()); + } + name += "\\n" + instruction->name(); + std::vector called_computations; + + // Pick different colors or shapes for instructions which are particularly + // expensive (eg, dot) and those which are unusual in some way or unique + // (eg, parameter). + switch (instruction->opcode()) { + // "Normal" instructions. Mostly cheap and elementwise. No call to + // embedded computations. In this case, use default color, shape and + // label. 
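      // For reference (an illustrative example with a hypothetical pointer id
      // and instruction name), such a default node is eventually emitted
      // further below as:
      //   140211 [label="add\n%add.1\nf32[2,2]", shape=ellipse, style=filled,
      //           fillcolor=peachpuff];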
+ case HloOpcode::kAbs: + case HloOpcode::kAdd: + case HloOpcode::kCeil: + case HloOpcode::kClamp: + case HloOpcode::kConcatenate: + case HloOpcode::kConvert: + case HloOpcode::kDivide: + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: + case HloOpcode::kEq: + case HloOpcode::kExp: + case HloOpcode::kFloor: + case HloOpcode::kGe: + case HloOpcode::kGt: + case HloOpcode::kIndex: + case HloOpcode::kLe: + case HloOpcode::kLog: + case HloOpcode::kLogicalAnd: + case HloOpcode::kLogicalNot: + case HloOpcode::kLogicalOr: + case HloOpcode::kLt: + case HloOpcode::kMaximum: + case HloOpcode::kMinimum: + case HloOpcode::kMultiply: + case HloOpcode::kNe: + case HloOpcode::kNegate: + case HloOpcode::kPad: + case HloOpcode::kPower: + case HloOpcode::kRemainder: + case HloOpcode::kReshape: + case HloOpcode::kReverse: + case HloOpcode::kSelect: + case HloOpcode::kSign: + case HloOpcode::kSlice: + case HloOpcode::kSort: + case HloOpcode::kSubtract: + case HloOpcode::kTanh: + case HloOpcode::kTuple: + case HloOpcode::kUpdate: + break; + + case HloOpcode::kBroadcast: + case HloOpcode::kTranspose: + StrAppend(&name, "\\n", "dims={", Join(instruction->dimensions(), ","), + "}"); + break; + case HloOpcode::kGetTupleElement: + StrAppend(&name, "\\nindex=", instruction->tuple_index()); + break; + case HloOpcode::kRng: + StrAppend(&name, "\\n", + RandomDistribution_Name(instruction->random_distribution())); + break; + case HloOpcode::kConstant: + shape = "boxed"; + color = "palegreen"; + if (ShapeUtil::IsScalar(instruction->shape())) { + StrAppend(&name, "\\n", "value=", LiteralUtil::GetAsString( + instruction->literal(), {})); + } + break; + case HloOpcode::kBitcast: + case HloOpcode::kCopy: + color = "white"; + break; + case HloOpcode::kCall: + color = "tomato"; + break; + case HloOpcode::kCustomCall: + color = "tomato4"; + StrAppend(&name, "\\n", + "custom_call_target=", instruction->custom_call_target()); + break; + case HloOpcode::kDot: + color = "slateblue"; + break; + case HloOpcode::kSend: + color = "purple"; + break; + case HloOpcode::kRecv: + color = "orange"; + break; + case HloOpcode::kMap: + color = "palevioletred"; + break; + case HloOpcode::kParameter: + // A single record node is created for all the parameter nodes with a + // port for each parameter instruction. No need to emit anything in this + // case. + continue; + case HloOpcode::kReduce: + StrAppend(&name, " dims=", Join(instruction->dimensions(), ",")); + color = "lightsalmon"; + break; + case HloOpcode::kSelectAndScatter: + case HloOpcode::kReduceWindow: + color = "lightsalmon"; + break; + case HloOpcode::kTrace: + color = "white"; + break; + case HloOpcode::kWhile: + color = "forestgreen"; + break; + case HloOpcode::kFusion: + color = "gray"; + break; + case HloOpcode::kConvolution: + color = "red"; + break; + case HloOpcode::kCrossReplicaSum: + color = "turquoise"; + break; + case HloOpcode::kInfeed: + color = "blue"; + break; + } + + // Create instruction node with appropriate label, shape, and color. + string label = + StrCat(name, "\\n", ShapeUtil::HumanString(instruction->shape())); + if (show_addresses) { + Appendf(&label, "\\n[%p]", instruction.get()); + } + if (show_layouts && LayoutUtil::HasLayout(instruction->shape())) { + string layout_string; + if (ShapeUtil::IsTuple(instruction->shape())) { + // For tuples, emit the full shape because the layout of a tuple is not + // represented in a single Layout field. 
+ layout_string = ShapeUtil::HumanStringWithLayout(instruction->shape()); + } else { + layout_string = + Join(instruction->shape().layout().minor_to_major(), ","); + } + StrAppend(&label, "\\nlayout={", layout_string, "}"); + } + if (hlo_execution_profile != nullptr) { + auto hlo_cycles_executed = + hlo_execution_profile->GetProfileResult(*instruction); + auto total_cycles_executed = + hlo_execution_profile->total_cycles_executed(); + if (hlo_cycles_executed > 0 && total_cycles_executed > 0) { + Appendf(&label, "\\n%% of cycles executed=%.2f", + (static_cast(hlo_cycles_executed) / + static_cast(total_cycles_executed)) * + 100); + } + } + Appendf(&graph_body, + "%s [label=\"%s\", shape=%s, style=filled, fillcolor=%s];\n", + InstructionId(instruction.get()).c_str(), label.c_str(), + shape.c_str(), color.c_str()); + + // Create edges from the instruction's operands to the instruction. + int64 operand_number = 0; + for (auto* operand : instruction->operands()) { + string src; + if (operand->opcode() == HloOpcode::kParameter) { + // If operand is a parameter, then select the proper partition (port) in + // the unified parameter node. + src = param_node_name + ":" + InstructionId(operand); + } else { + src = InstructionId(operand); + } + Appendf(&graph_body, "%s -> %s", src.c_str(), + InstructionId(instruction.get()).c_str()); + if (instruction->operand_count() > 1) { + Appendf(&graph_body, " [headlabel=\"%lld\",labeldistance=2]", + operand_number); + } + StrAppend(&graph_body, ";\n"); + ++operand_number; + } + + // Fusion nodes are handled specially because they contain nested + // expressions. + if (instruction->opcode() == HloOpcode::kFusion) { + string cluster_name = + StrCat("cluster_", InstructionId(instruction.get())); + StrAppend(&graph_body, "subgraph ", cluster_name, " {\n"); + StrAppend(&graph_body, + "label=\"fused expression\";\nstyle=filled;\n" + "color=lightgrey;\n"); + StrAppend(&graph_body, InstructionSequenceGraph( + instruction->fused_instructions(), + show_addresses, show_layouts, + intercomputation_edges, hlo_execution_profile), + "}\n"); + string fusion_edge = + StrCat(InstructionId(instruction->fused_expression_root()), " -> ", + InstructionId(instruction.get()), + " [ style = \"dotted\", arrowsize=0.0, ltail=", cluster_name, + " ];\n"); + intercomputation_edges->push_back(fusion_edge); + } else { + // Add a dotted edge between the instruction and any computations that the + // instruction calls. + for (auto* computation : instruction->MakeCalledComputationsSet()) { + string cluster_name = StrCat("cluster_", ComputationId(computation)); + string call_edge = Printf( + "%s -> %s [ style=dashed; ltail=%s ];\n", + InstructionId(computation->root_instruction()).c_str(), + InstructionId(instruction.get()).c_str(), cluster_name.c_str()); + intercomputation_edges->push_back(call_edge); + } + } + } + return graph_body; +} + +string ComputationToDotGraph(const HloComputation& computation, + const string& label, bool show_addresses, + bool show_layouts, + const HloExecutionProfile* hlo_execution_profile) { + string graph_label = StrCat(label, "\\n", computation.name()); + if (hlo_execution_profile != nullptr) { + auto cycles = hlo_execution_profile->total_cycles_executed(); + Appendf(&graph_label, "\\ntotal cycles = %lld (%s)", cycles, + tensorflow::strings::HumanReadableNum(cycles).c_str()); + } + string graph = + Printf("digraph G {\nrankdir=TB;\ncompound=true;\nlabel=\"%s\"\n", + graph_label.c_str()); + + // Emit embedded computations as subgraph clusters. 
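  // Overall shape of the generated output (an illustrative sketch; node ids
  // are pointer-derived and vary from run to run):
  //
  //   digraph G {
  //     rankdir=TB; compound=true; label="<label>\n<computation name>"
  //     subgraph cluster_<id> { label="<embedded computation>"; ... }
  //     ... nodes and operand edges of this computation ...
  //     ... dashed/dotted inter-computation edges, emitted last ...
  //   }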
+ std::vector intercomputation_edges; + for (auto embedded : computation.MakeEmbeddedComputationsList()) { + string graph_body = InstructionSequenceGraph( + embedded->instructions(), show_addresses, show_layouts, + &intercomputation_edges, hlo_execution_profile); + Appendf(&graph, "subgraph cluster_%s {\nlabel=\"%s\";\n%s}\n", + ComputationId(embedded).c_str(), embedded->name().c_str(), + graph_body.c_str()); + } + StrAppend(&graph, + InstructionSequenceGraph(computation.instructions(), show_addresses, + show_layouts, &intercomputation_edges, + hlo_execution_profile)); + + // Edges between computations (subgraph clusters) must be emitted last for the + // graph to be rendered properly for some reason. + StrAppend(&graph, Join(intercomputation_edges, "\n"), "}\n"); + + return graph; +} + +tensorflow::mutex& RendererMutex() { + static tensorflow::mutex* mu = new tensorflow::mutex; + return *mu; +} + +std::map* GraphRenderers() { + static auto* graph_renderers = new std::map(); + return graph_renderers; +} + +GraphRendererInterface* GetGraphRenderer() { + tensorflow::mutex_lock lock(RendererMutex()); + auto* graph_renderers = GraphRenderers(); + auto it = graph_renderers->rbegin(); + CHECK(it != graph_renderers->rend()) << "No registered graph dumpers"; + return it->second; +} + +} // namespace + +Registrar::Registrar(GraphRendererInterface* dumper, int priority) { + tensorflow::mutex_lock lock(RendererMutex()); + auto* graph_renderers = GraphRenderers(); + graph_renderers->emplace(priority, dumper); +} + +namespace { + +class FileGraphRenderer : public GraphRendererInterface { + public: + string RenderGraph(const string& graph) override { + static std::atomic output_num(0); + legacy_flags::HloGraphDumperFlags* flags = + legacy_flags::GetHloGraphDumperFlags(); + string path = StrCat(flags->xla_hlo_dump_graph_path, "hlo_graph_", + output_num++, ".dot"); + tensorflow::Status status = + tensorflow::WriteStringToFile(tensorflow::Env::Default(), path, graph); + if (!status.ok()) { + LOG(WARNING) << "Saving HLO graph failed: " << status; + } + return path; + } +}; + +XLA_REGISTER_GRAPH_RENDERER(FileGraphRenderer, 0); + +} // namespace + +string DumpGraph(const HloComputation& computation, const string& label, + bool show_addresses, bool show_layouts, + const HloExecutionProfile* hlo_execution_profile) { + string graph = ComputationToDotGraph(computation, label, show_addresses, + show_layouts, hlo_execution_profile); + + string graph_url = GetGraphRenderer()->RenderGraph(graph); + LOG(INFO) << "computation " << computation.name() << " [" << label + << "]: " << graph_url; + return graph_url; +} + +void DumpText(const HloModule& module, const string& label, + const string& directory_path) { + Env* env = Env::Default(); + TF_CHECK_OK(env->RecursivelyCreateDir(directory_path)); + string prefix = StrCat(env->NowMicros()); + string path = JoinPath(directory_path, StrCat(prefix, "-", label, ".txt")); + TF_CHECK_OK(WriteStringToFile(env, path, module.ToString())); +} + +} // namespace hlo_graph_dumper +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.h b/tensorflow/compiler/xla/service/hlo_graph_dumper.h new file mode 100644 index 0000000000..45fd46352f --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.h @@ -0,0 +1,76 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_GRAPH_DUMPER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_GRAPH_DUMPER_H_ + +#include + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_execution_profile.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { +namespace hlo_graph_dumper { + +// Dumps a graph of the computation to the GraphViz server and returns +// a description of the rendered graph (e.g., a URL). +string DumpGraph(const HloComputation& computation, const string& label, + bool show_addresses, bool show_layouts, + const HloExecutionProfile* hlo_execution_profile = nullptr); + +// Dumps the HloModule::ToString() as a file into the provided directory path +// suffixed with the provided label. +void DumpText(const HloModule& module, const string& label, + const string& directory_path); + +// Abstract interface for classes that render DOT graphs. +class GraphRendererInterface { + public: + virtual ~GraphRendererInterface() = default; + + // Renders a DOT graph, returning a description of the rendered output + // (e.g., a URL) + virtual string RenderGraph(const string& graph) = 0; +}; + +// Graph renderers may be added using a registration mechanism, e.g.: +// XLA_REGISTER_GRAPH_RENDERER(AGraphRendererClass, 100) +// The renderer with the highest numeric priority value is used. + +#define XLA_REGISTER_GRAPH_RENDERER(factory, ...) \ + XLA_INTERNAL_REGISTER_GRAPH_RENDERER(factory, __COUNTER__, ##__VA_ARGS__) + +// Internal implementation details below this point. + +// Class that registers a graph renderer. Higher-priority renders are chosen +// first. +class Registrar { + public: + Registrar(GraphRendererInterface* dumper, int priority); +}; + +#define XLA_INTERNAL_REGISTER_GRAPH_RENDERER(factory, ctr, ...) \ + static ::xla::hlo_graph_dumper::Registrar \ + XLA_INTERNAL_REGISTER_GRAPH_RENDERER_NAME(ctr)(new factory, \ + ##__VA_ARGS__) + +// __COUNTER__ must go through another macro to be properly expanded +#define XLA_INTERNAL_REGISTER_GRAPH_RENDERER_NAME(ctr) ___##ctr##__object_ + +} // namespace hlo_graph_dumper +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_GRAPH_DUMPER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc new file mode 100644 index 0000000000..7ae0a995af --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -0,0 +1,1921 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" + +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/protobuf_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/window_util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +/* static */ std::unique_ptr HloInstruction::CreateParameter( + int64 parameter_number, const Shape& shape, const string& name) { + auto instruction = + WrapUnique(new HloInstruction(HloOpcode::kParameter, shape)); + instruction->parameter_number_ = parameter_number; + instruction->parameter_name_ = name; + instruction->name_ = "%" + name; + return instruction; +} + +/* static */ std::unique_ptr HloInstruction::CreateTrace( + const string& tag, HloInstruction* operand) { + auto instruction = + WrapUnique(new HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil())); + instruction->operands_.push_back(operand); + instruction->literal_.reset(new Literal); + *instruction->literal_->mutable_u8s() += tag; + return instruction; +} + +/* static */ std::unique_ptr HloInstruction::CreateConstant( + std::unique_ptr literal) { + auto instruction = + WrapUnique(new HloInstruction(HloOpcode::kConstant, literal->shape())); + instruction->literal_ = std::move(literal); + return instruction; +} + +/* static */ std::unique_ptr +HloInstruction::CreateGetTupleElement(const Shape& shape, + HloInstruction* operand, int64 index) { + auto instruction = + WrapUnique(new HloInstruction(HloOpcode::kGetTupleElement, shape)); + instruction->tuple_index_ = index; + instruction->AppendOperand(operand); + return instruction; +} + +/* static */ std::unique_ptr HloInstruction::CreateRng( + const Shape& shape, RandomDistribution distribution, + tensorflow::gtl::ArraySlice parameters) { + auto instruction = WrapUnique(new HloInstruction(HloOpcode::kRng, shape)); + instruction->distribution_ = distribution; + instruction->shape_ = shape; + for (HloInstruction* param : parameters) { + instruction->AppendOperand(param); + } + return instruction; +} + +/* static */ std::unique_ptr HloInstruction::CreateNary( + const Shape& shape, HloOpcode opcode, + tensorflow::gtl::ArraySlice operands) { + if (opcode == HloOpcode::kCopy) { + // It is impossible to copy an opaque shape, we don't know how big it is. 
+ CHECK(!ShapeUtil::IsOpaque(shape)); + } + auto instruction = WrapUnique(new HloInstruction(opcode, shape)); + for (auto operand : operands) { + instruction->AppendOperand(operand); + } + return instruction; +} + +/* static */ std::unique_ptr HloInstruction::CreateUnary( + const Shape& shape, HloOpcode opcode, HloInstruction* operand) { + // Only certain opcodes are supported with CreateUnary: opcodes of unary + // instructions with no auxiliary fields. + switch (opcode) { + case HloOpcode::kAbs: + case HloOpcode::kBitcast: + case HloOpcode::kCeil: + case HloOpcode::kCopy: + case HloOpcode::kExp: + case HloOpcode::kFloor: + case HloOpcode::kLog: + case HloOpcode::kLogicalNot: + case HloOpcode::kNegate: + case HloOpcode::kSign: + case HloOpcode::kSort: + case HloOpcode::kTanh: + break; + default: + LOG(FATAL) << "Invalid unary instruction opcode " + << HloOpcodeString(opcode); + } + return CreateNary(shape, opcode, {operand}); +} + +/* static */ std::unique_ptr HloInstruction::CreateBinary( + const Shape& shape, HloOpcode opcode, HloInstruction* lhs, + HloInstruction* rhs) { + // Only certain opcodes are supported with CreateBinary: opcodes of binary + // instructions with no auxiliary fields. + switch (opcode) { + case (HloOpcode::kAdd): + case (HloOpcode::kDivide): + case (HloOpcode::kDot): + case (HloOpcode::kEq): + case (HloOpcode::kGe): + case (HloOpcode::kGt): + case (HloOpcode::kLe): + case (HloOpcode::kLt): + case (HloOpcode::kMaximum): + case (HloOpcode::kMinimum): + case (HloOpcode::kMultiply): + case (HloOpcode::kNe): + case (HloOpcode::kPower): + case (HloOpcode::kRemainder): + case (HloOpcode::kSubtract): + case (HloOpcode::kLogicalAnd): + case (HloOpcode::kLogicalOr): + break; + default: + LOG(FATAL) << "Invalid binary instruction opcode " + << HloOpcodeString(opcode); + } + return CreateNary(shape, opcode, {lhs, rhs}); +} + +/* static */ std::unique_ptr HloInstruction::CreateTernary( + const Shape& shape, HloOpcode opcode, HloInstruction* lhs, + HloInstruction* rhs, HloInstruction* ehs) { + // Only certain opcodes are supported with CreateTernary: opcodes of ternary + // instructions with no auxiliary fields. 
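  // Illustrative use of the binary and ternary factories (a sketch only;
  // `shape`, `lhs`, `rhs`, `pred`, `on_true` and `on_false` are hypothetical,
  // shape-compatible instructions, and kSelect is assumed to take the
  // predicate as its first operand):
  //
  //   auto max = HloInstruction::CreateBinary(shape, HloOpcode::kMaximum,
  //                                           lhs, rhs);
  //   auto sel = HloInstruction::CreateTernary(shape, HloOpcode::kSelect,
  //                                            pred, on_true, on_false);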
+ switch (opcode) { + case (HloOpcode::kClamp): + case (HloOpcode::kSelect): + break; + default: + LOG(FATAL) << "Invalid ternary instruction opcode " + << HloOpcodeString(opcode); + } + return CreateNary(shape, opcode, {lhs, rhs, ehs}); +} + +/* static */ std::unique_ptr HloInstruction::CreateVariadic( + const Shape& shape, HloOpcode opcode, + tensorflow::gtl::ArraySlice operands) { + CHECK_EQ(HloOpcode::kTuple, opcode); + return CreateNary(shape, opcode, operands); +} + +/* static */ std::unique_ptr HloInstruction::CreateMap( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + HloComputation* map_computation, + tensorflow::gtl::ArraySlice static_operands) { + CHECK(static_operands.empty()) << "static_operands not yet supported"; + auto instruction = WrapUnique(new HloInstruction(HloOpcode::kMap, shape)); + for (auto operand : operands) { + instruction->AppendOperand(operand); + } + instruction->to_apply_ = map_computation; + return instruction; +} + +/* static */ std::unique_ptr HloInstruction::CreateConvolve( + const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, + const Window& window, + const ConvolutionDimensionNumbers& dimension_numbers) { + auto instruction = + WrapUnique(new HloInstruction(HloOpcode::kConvolution, shape)); + instruction->AppendOperand(lhs); + instruction->AppendOperand(rhs); + instruction->window_ = MakeUnique(window); + instruction->convolution_dimension_numbers_ = + MakeUnique(dimension_numbers); + return instruction; +} + +/* static */ std::unique_ptr +HloInstruction::CreateCrossReplicaSum(const Shape& shape, + HloInstruction* operand) { + auto instruction = + WrapUnique(new HloInstruction(HloOpcode::kCrossReplicaSum, shape)); + instruction->AppendOperand(operand); + return instruction; +} + +/* static */ std::unique_ptr HloInstruction::CreateInfeed( + const Shape& shape) { + return WrapUnique(new HloInstruction(HloOpcode::kInfeed, shape)); +} + +/* static */ std::unique_ptr HloInstruction::CreateSend( + HloInstruction* operand) { + auto instruction = + WrapUnique(new HloInstruction(HloOpcode::kSend, ShapeUtil::MakeNil())); + instruction->AppendOperand(operand); + return instruction; +} + +/* static */ std::unique_ptr HloInstruction::CreateRecv( + const Shape& shape) { + return WrapUnique(new HloInstruction(HloOpcode::kRecv, shape)); +} + +/* static */ std::unique_ptr HloInstruction::CreateReverse( + const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice dimensions) { + auto instruction = WrapUnique(new HloInstruction(HloOpcode::kReverse, shape)); + instruction->AppendOperand(operand); + instruction->dimensions_.assign(dimensions.begin(), dimensions.end()); + return instruction; +} + +/* static */ std::unique_ptr HloInstruction::CreateWhile( + const Shape& shape, HloComputation* condition, HloComputation* body, + HloInstruction* init) { + auto instruction = WrapUnique(new HloInstruction(HloOpcode::kWhile, shape)); + instruction->AppendOperand(init); + instruction->condition_ = condition; + instruction->body_ = body; + return instruction; +} + +/* static */ std::unique_ptr HloInstruction::CreateSlice( + const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices) { + auto instruction = WrapUnique(new HloInstruction(HloOpcode::kSlice, shape)); + instruction->AppendOperand(operand); + instruction->slice_starts_.assign(start_indices.begin(), start_indices.end()); + instruction->slice_limits_.assign(limit_indices.begin(), limit_indices.end()); + return 
instruction; +} + +/* static */ std::unique_ptr HloInstruction::CreateDynamicSlice( + const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, + tensorflow::gtl::ArraySlice slice_sizes) { + auto instruction = + WrapUnique(new HloInstruction(HloOpcode::kDynamicSlice, shape)); + instruction->AppendOperand(operand); + instruction->AppendOperand(start_indices); + instruction->dynamic_slice_sizes_.assign(slice_sizes.begin(), + slice_sizes.end()); + return instruction; +} + +/* static */ std::unique_ptr +HloInstruction::CreateDynamicUpdateSlice(const Shape& shape, + HloInstruction* operand, + HloInstruction* update, + HloInstruction* start_indices) { + auto instruction = + WrapUnique(new HloInstruction(HloOpcode::kDynamicUpdateSlice, shape)); + instruction->AppendOperand(operand); + instruction->AppendOperand(update); + instruction->AppendOperand(start_indices); + return instruction; +} + +/* static */ std::unique_ptr HloInstruction::CreateConcatenate( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + int64 dimension) { + auto instruction = + WrapUnique(new HloInstruction(HloOpcode::kConcatenate, shape)); + for (auto operand : operands) { + instruction->AppendOperand(operand); + } + instruction->dimensions_.push_back(dimension); + return instruction; +} + +/* static */ std::unique_ptr HloInstruction::CreateConvert( + const Shape& shape, HloInstruction* operand) { + auto instruction = WrapUnique(new HloInstruction(HloOpcode::kConvert, shape)); + instruction->AppendOperand(operand); + return instruction; +} + +/* static */ std::unique_ptr HloInstruction::CreateReduce( + const Shape& shape, HloInstruction* arg, HloInstruction* init_value, + tensorflow::gtl::ArraySlice dimensions_to_reduce, + HloComputation* reduce_computation) { + auto instruction = WrapUnique(new HloInstruction(HloOpcode::kReduce, shape)); + instruction->AppendOperand(arg); + instruction->AppendOperand(init_value); + instruction->dimensions_.assign(dimensions_to_reduce.begin(), + dimensions_to_reduce.end()); + instruction->to_apply_ = reduce_computation; + return instruction; +} + +/* static */ std::unique_ptr HloInstruction::CreateReduceWindow( + const Shape& shape, HloInstruction* operand, HloInstruction* init_value, + const Window& window, HloComputation* reduce_computation) { + auto instruction = + WrapUnique(new HloInstruction(HloOpcode::kReduceWindow, shape)); + instruction->AppendOperand(operand); + instruction->AppendOperand(init_value); + instruction->to_apply_ = reduce_computation; + instruction->window_ = MakeUnique(window); + return instruction; +} + +/* static */ std::unique_ptr +HloInstruction::CreateSelectAndScatter( + const Shape& shape, HloInstruction* operand, HloComputation* select, + const Window& window, HloInstruction* source, HloInstruction* init_value, + HloComputation* scatter) { + auto instruction = + WrapUnique(new HloInstruction(HloOpcode::kSelectAndScatter, shape)); + instruction->AppendOperand(operand); + instruction->AppendOperand(source); + instruction->AppendOperand(init_value); + instruction->select_ = select; + instruction->scatter_ = scatter; + instruction->window_ = MakeUnique(window); + return instruction; +} + +/* static */ std::unique_ptr HloInstruction::CreateBroadcast( + const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + auto instruction = + WrapUnique(new HloInstruction(HloOpcode::kBroadcast, shape)); + instruction->AppendOperand(operand); + instruction->dimensions_.assign(broadcast_dimensions.begin(), + 
broadcast_dimensions.end()); + return instruction; +} + +/* static */ std::unique_ptr HloInstruction::CreatePad( + const Shape& shape, HloInstruction* operand, HloInstruction* padding_value, + const PaddingConfig& padding_config) { + auto instruction = WrapUnique(new HloInstruction(HloOpcode::kPad, shape)); + instruction->AppendOperand(operand); + instruction->AppendOperand(padding_value); + instruction->padding_config_ = MakeUnique(padding_config); + return instruction; +} + +/* static */ std::unique_ptr HloInstruction::CreateReshape( + const Shape& shape, HloInstruction* operand) { + CHECK_EQ(ShapeUtil::ElementsIn(shape), + ShapeUtil::ElementsIn(operand->shape())); + auto instruction = WrapUnique(new HloInstruction(HloOpcode::kReshape, shape)); + instruction->AppendOperand(operand); + return instruction; +} + +/* static */ std::unique_ptr HloInstruction::CreateTranspose( + const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice dimensions) { + CHECK_EQ(shape.dimensions().size(), dimensions.size()); + CHECK_EQ(shape.dimensions().size(), operand->shape().dimensions().size()); + CHECK(std::equal(operand->shape().dimensions().begin(), + operand->shape().dimensions().end(), + Permute(dimensions, shape.dimensions()).begin())); + auto instruction = + WrapUnique(new HloInstruction(HloOpcode::kTranspose, shape)); + instruction->AppendOperand(operand); + instruction->dimensions_.assign(dimensions.begin(), dimensions.end()); + return instruction; +} + +/* static */ std::unique_ptr HloInstruction::CreateFusion( + const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) { + auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFusion, shape)); + instruction->fusion_kind_ = fusion_kind; + instruction->CloneAndFuseInternal(fused_root); + instruction->CheckFusionInstruction(); + return instruction; +} + +/* static */ std::unique_ptr +HloInstruction::CreateFusionForBackwardConvolution( + const Shape& shape, FusionKind fusion_kind, const Window& window, + const ConvolutionDimensionNumbers& conv_dnums, HloInstruction* fused_root) { + std::unique_ptr fusion = + CreateFusion(shape, fusion_kind, fused_root); + fusion->window_ = MakeUnique(window); + fusion->convolution_dimension_numbers_ = + MakeUnique(conv_dnums); + return fusion; +} + +HloInstruction* HloInstruction::FuseInstruction( + HloInstruction* instruction_to_fuse) { + CHECK_EQ(opcode_, HloOpcode::kFusion); + + // This fusion instruction must be a user of instruction_to_fuse. + CHECK_NE(0, instruction_to_fuse->users().count(this)); + HloInstruction* fused_instruction = CloneAndFuseInternal(instruction_to_fuse); + CheckFusionInstruction(); + return fused_instruction; +} + +HloInstruction* HloInstruction::CloneAndFuseInternal( + HloInstruction* instruction_to_fuse) { + CHECK_EQ(opcode_, HloOpcode::kFusion); + CHECK(instruction_to_fuse->IsFusable()); + + bool new_fusion_instruction = fused_instructions_.empty(); + fused_instructions_.emplace_back(instruction_to_fuse->Clone()); + HloInstruction* clone = fused_instructions_.back().get(); + clone->parent_fusion_instruction_ = this; + + if (new_fusion_instruction) { + fused_root_ = clone; + } else { + // instruction_to_fuse is necessarily an operand of the fusion instruction. + // After fusion this will no longer be the case. Remove the operand from the + // operand list and remove its corresponding fused parameter + // instruction. Renumber parameters as necessary to make parameter numbers + // consistent with their index in the fused_parameter_ vector. 
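+    // For example, if this fusion instruction had (hypothetical) operands
+    // {A, B, C} paired with fused parameters {p0, p1, p2} and B is the
+    // instruction being fused, then p1's uses are redirected to the clone of
+    // B, the pair (B, p1) is erased, and the surviving parameter for C is
+    // renumbered from 2 to 1, restoring the invariant
+    // fused_parameters_[i]->parameter_number() == i.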
+ CHECK(std::find(operands_.begin(), operands_.end(), instruction_to_fuse) != + operands_.end()); + for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) { + if (instruction_to_fuse == operands_[operand_num]) { + // replace the fused parameter instruction's uses with the clone. + HloInstruction* fused_parameter = fused_parameters_[operand_num]; + fused_parameter->ReplaceAllUsesWith(clone); + + // Remove the corresponding fused parameter and operand from their + // respective vectors. + fused_parameters_.erase(fused_parameters_.begin() + operand_num); + operands_.erase(operands_.begin() + operand_num); + + // Renumber fused parameter numbers to match the vector index. + while (operand_num < fused_parameters_.size()) { + fused_parameters_[operand_num]->parameter_number_ = operand_num; + operand_num++; + } + // Throw removed fused parameter instruction away. + auto inst_it = + std::find_if(fused_instructions_.begin(), fused_instructions_.end(), + [=](const std::unique_ptr& inst) { + return inst.get() == fused_parameter; + }); + CHECK(inst_it != fused_instructions_.end()); + fused_instructions_.erase(inst_it); + break; + } + } + // We've cloned instruction_to_fuse into this fusion instruction, so this + // fusion instruction is no longer a use of instruction_to_fuse. + instruction_to_fuse->RemoveUser(this); + } + + // Add each operand of the clone as an operand of the fusion instruction. A + // complication is that some clone operands may already be operands of the + // fusion instruction. + for (int64 operand_num = 0; operand_num < clone->operand_count(); + ++operand_num) { + HloInstruction* operand = clone->mutable_operand(operand_num); + + // See if this operand is already an operand of the fusion node. + CHECK_EQ(operands_.size(), fused_parameters_.size()); + HloInstruction* fused_param = nullptr; + for (int64 i = 0; i < operands_.size(); ++i) { + if (operands_[i] == operand) { + fused_param = fused_parameters_[i]; + break; + } + } + + if (fused_param == nullptr) { + // Clone's operand was not already an operand of the fusion + // instruction. Add it as an operand and add a corresponding fused + // parameter instruction. + int64 param_no = fused_parameters_.size(); + std::unique_ptr param_instruction = + CreateParameter(param_no, operand->shape(), "fusion_param"); + + param_instruction->parent_fusion_instruction_ = this; + fused_parameters_.push_back(param_instruction.get()); + fused_instructions_.push_back(std::move(param_instruction)); + AppendOperand(operand); + + fused_param = fused_instructions_.back().get(); + } + clone->ReplaceOperandWith(operand_num, fused_param); + } + + return clone; +} + +RandomDistribution HloInstruction::random_distribution() const { + CHECK_EQ(opcode_, HloOpcode::kRng); + return distribution_; +} + +namespace { + +// Adds any HloComputations this instruction calls directly to the given set. 
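+// For example, a kReduce or kMap instruction contributes its to_apply()
+// computation, a kWhile contributes both its condition and body, a
+// kSelectAndScatter contributes its select and scatter computations, and a
+// kFusion recurses into its fused instructions so that computations called
+// from inside the fusion are collected as well.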
+void CalledComputationsInternal( + const HloInstruction& instruction, + std::set* called_computations) { + switch (instruction.opcode()) { + case HloOpcode::kCall: + case HloOpcode::kMap: + case HloOpcode::kReduce: + case HloOpcode::kReduceWindow: + called_computations->insert(instruction.to_apply()); + break; + case HloOpcode::kSelectAndScatter: + called_computations->insert(instruction.select()); + called_computations->insert(instruction.scatter()); + break; + case HloOpcode::kWhile: + called_computations->insert(instruction.while_condition()); + called_computations->insert(instruction.while_body()); + break; + case HloOpcode::kFusion: + for (const auto& fused_instruction : instruction.fused_instructions()) { + CalledComputationsInternal(*fused_instruction, called_computations); + } + break; + default: + break; + } +} + +} // namespace + +std::set HloInstruction::MakeCalledComputationsSet() const { + std::set called_computations; + CalledComputationsInternal(*this, &called_computations); + return called_computations; +} + +void HloInstruction::CheckFusionInstruction() const { + CHECK_EQ(opcode_, HloOpcode::kFusion); + + // All instructions owned by this fusion instruction must be fused, and the + // parent fusion instruction of the fused instructions must be 'this'. + for (auto& instruction : fused_instructions_) { + CHECK(instruction->IsFused()); + CHECK_EQ(this, instruction->fusion_instruction()); + } + + // Fused root instruction and fused parameters must all be owned by the fusion + // instruction. + bool root_owned = false; + std::vector parameter_owned(fused_parameters_.size(), false); + for (auto& instruction : fused_instructions_) { + if (fused_root_ == instruction.get()) { + CHECK(!root_owned); + root_owned = true; + } + for (int i = 0; i < fused_parameters_.size(); ++i) { + if (fused_parameters_[i] == instruction.get()) { + CHECK(!parameter_owned[i]); + parameter_owned[i] = true; + } + } + } + CHECK(root_owned); + // Make sure all the parameter_owned entries are set + for (int i = 0; i < parameter_owned.size(); i++) { + CHECK(parameter_owned[i]); + } + + // Fused root must have no users. + CHECK_EQ(0, fused_root_->user_count()); + + // All uses of fused instructions must be in the fusion instruction, and every + // non-root instruction must have at least one use. + for (auto& instruction : fused_instructions_) { + if (instruction.get() != fused_root_) { + CHECK_GT(instruction->user_count(), 0); + for (auto& user : instruction->users()) { + CHECK(user->IsFused()); + CHECK_EQ(this, user->fusion_instruction()); + } + } + } + + // Fused parameter instructions must be numbered contiguously and match up + // (shapes compatible) with their respective operand. + CHECK_EQ(operands_.size(), fused_parameters_.size()); + std::vector parameter_numbers(fused_parameters_.size(), false); + for (auto fused_param : fused_parameters_) { + int64 param_no = fused_param->parameter_number(); + CHECK_GE(param_no, 0); + CHECK_LT(param_no, fused_parameters_.size()); + CHECK(!parameter_numbers[param_no]); + parameter_numbers[param_no] = true; + CHECK(ShapeUtil::Compatible(fused_param->shape(), + operands_[param_no]->shape())); + } + // Make sure all the parameter_numbers entries were seen + for (int i = 0; i < parameter_numbers.size(); i++) { + CHECK(parameter_numbers[i]); + } + + // Operands must be distinct. 
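+  // (Copying the operand pointers into a set and comparing its size with
+  // operands_.size() below catches any duplicated operand.)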
+ std::set operand_set(operands_.begin(), operands_.end()); + CHECK_EQ(operand_set.size(), operands_.size()); +} + +/* static */ std::unique_ptr HloInstruction::CreateCall( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + HloComputation* computation) { + std::unique_ptr instruction = + WrapUnique(new HloInstruction(HloOpcode::kCall, shape)); + for (auto operand : operands) { + instruction->AppendOperand(operand); + } + instruction->to_apply_ = computation; + return instruction; +} + +/* static */ std::unique_ptr HloInstruction::CreateCustomCall( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + tensorflow::StringPiece custom_call_target) { + std::unique_ptr instruction = + WrapUnique(new HloInstruction(HloOpcode::kCustomCall, shape)); + for (auto operand : operands) { + instruction->AppendOperand(operand); + } + instruction->custom_call_target_ = custom_call_target.ToString(); + return instruction; +} + +/* static */ std::unique_ptr HloInstruction::CreateTuple( + tensorflow::gtl::ArraySlice elements) { + std::vector element_shapes; + for (auto element : elements) { + element_shapes.push_back(element->shape()); + } + Shape tuple_shape = ShapeUtil::MakeTupleShape(element_shapes); + return CreateVariadic(tuple_shape, HloOpcode::kTuple, elements); +} + +std::unique_ptr HloInstruction::CloneWithNewOperands( + const Shape& shape, tensorflow::gtl::ArraySlice operands) { + // Explicitly call the factory for the instruction type. This is more robust + // in the face of code changes than copying fields explicitly. This also + // properly sets the user fields of the operands. + switch (opcode_) { + // Unary ops. + case HloOpcode::kAbs: + case HloOpcode::kBitcast: + case HloOpcode::kCeil: + case HloOpcode::kCopy: + case HloOpcode::kExp: + case HloOpcode::kFloor: + case HloOpcode::kLog: + case HloOpcode::kLogicalNot: + case HloOpcode::kNegate: + case HloOpcode::kSign: + case HloOpcode::kSort: + case HloOpcode::kTanh: + CHECK_EQ(operands.size(), 1); + return CreateUnary(shape, opcode_, operands[0]); + // Binary ops. + case HloOpcode::kAdd: + case HloOpcode::kDivide: + case HloOpcode::kMultiply: + case HloOpcode::kSubtract: + case HloOpcode::kEq: + case HloOpcode::kGe: + case HloOpcode::kGt: + case HloOpcode::kLe: + case HloOpcode::kLt: + case HloOpcode::kNe: + case HloOpcode::kDot: + case HloOpcode::kMaximum: + case HloOpcode::kMinimum: + case HloOpcode::kPower: + case HloOpcode::kRemainder: + case HloOpcode::kLogicalAnd: + case HloOpcode::kLogicalOr: + CHECK_EQ(operands.size(), 2); + return CreateBinary(shape, opcode_, operands[0], operands[1]); + // Ternary ops. + case HloOpcode::kClamp: + case HloOpcode::kSelect: + CHECK_EQ(operands.size(), 3); + return CreateTernary(shape, opcode_, operands[0], operands[1], + operands[2]); + // Other supported ops. 
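+    // Most of these carry extra state beyond their operands -- dimensions,
+    // windows, slice bounds, called computations, literals -- which is copied
+    // from this instruction into the corresponding factory call below.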
+ case HloOpcode::kBroadcast: + CHECK_EQ(operands.size(), 1); + return CreateBroadcast(shape, operands[0], dimensions_); + case HloOpcode::kCall: + return CreateCall(shape, operands, to_apply_); + case HloOpcode::kCustomCall: + return CreateCustomCall(shape, operands, custom_call_target_); + case HloOpcode::kConcatenate: + return CreateConcatenate(shape, operands, dimensions(0)); + case HloOpcode::kConvert: + CHECK_EQ(operands.size(), 1); + return CreateConvert(shape, operands[0]); + case HloOpcode::kConvolution: + CHECK_EQ(operands.size(), 2); + return CreateConvolve(shape, operands[0], operands[1], *window_, + *convolution_dimension_numbers_); + case HloOpcode::kCrossReplicaSum: + CHECK_EQ(operands.size(), 1); + return CreateCrossReplicaSum(shape, operands[0]); + case HloOpcode::kGetTupleElement: + CHECK_EQ(operands.size(), 1); + return CreateGetTupleElement(shape, operands[0], tuple_index()); + case HloOpcode::kMap: + return CreateMap(shape, operands, to_apply_); + case HloOpcode::kPad: + CHECK_EQ(operands.size(), 2); + return CreatePad(shape, operands[0], operands[1], *padding_config_); + case HloOpcode::kReduce: + CHECK_EQ(operands.size(), 2); + return CreateReduce(shape, operands[0], operands[1], dimensions_, + to_apply_); + case HloOpcode::kReduceWindow: + CHECK_EQ(operands.size(), 2); + return CreateReduceWindow(shape, operands[0], operands[1], *window_, + to_apply_); + case HloOpcode::kSelectAndScatter: + CHECK_EQ(operands.size(), 3); + return CreateSelectAndScatter(shape, operands[0], select_, *window_, + operands[1], operands[2], scatter_); + case HloOpcode::kRecv: + CHECK_EQ(operands.size(), 0); + return CreateRecv(shape); + case HloOpcode::kReverse: + CHECK_EQ(operands.size(), 1); + return CreateReverse(shape, operands[0], dimensions_); + case HloOpcode::kRng: + return CreateRng(shape, distribution_, operands); + case HloOpcode::kReshape: + CHECK_EQ(operands.size(), 1); + return CreateReshape(shape, operands[0]); + case HloOpcode::kSend: + CHECK_EQ(operands.size(), 1); + return CreateSend(operands[0]); + case HloOpcode::kSlice: + CHECK_EQ(operands.size(), 1); + return CreateSlice(shape, operands[0], slice_starts_, slice_limits_); + case HloOpcode::kDynamicSlice: + return CreateDynamicSlice(shape, operands[0], operands[1], + dynamic_slice_sizes_); + case HloOpcode::kDynamicUpdateSlice: + CHECK_EQ(operands.size(), 3); + return CreateDynamicUpdateSlice(shape, operands[0], operands[1], + operands[2]); + case HloOpcode::kTranspose: + CHECK_EQ(operands.size(), 1); + return CreateTranspose(shape, operands[0], dimensions_); + case HloOpcode::kTuple: + return CreateTuple(operands_); + case HloOpcode::kWhile: + CHECK_EQ(operands.size(), 1); + return CreateWhile(shape, condition_, body_, operands[0]); + case HloOpcode::kConstant: + return CreateConstant(LiteralUtil::CloneToUnique(*literal_)); + case HloOpcode::kFusion: + return CloneFusionWithNewOperands(shape, operands); + case HloOpcode::kParameter: + return CreateParameter(parameter_number_, shape, parameter_name_); + // Unsupported ops for cloning. 
+ case HloOpcode::kUpdate: + case HloOpcode::kIndex: + case HloOpcode::kInfeed: + case HloOpcode::kTrace: + LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode_); + } +} + +std::unique_ptr HloInstruction::Clone() { + std::unique_ptr clone = + CloneWithNewOperands(shape_, operands_); + clone->name_ = name() + ".clone"; + return clone; +} + +std::unique_ptr HloInstruction::CloneFusionWithNewOperands( + const Shape& shape, tensorflow::gtl::ArraySlice operands) { + CHECK_EQ(opcode_, HloOpcode::kFusion); + + auto new_instruction = + WrapUnique(new HloInstruction(HloOpcode::kFusion, shape)); + // Add the operands to our new fusion instruction. + for (HloInstruction* new_operand : operands) { + new_instruction->AppendOperand(new_operand); + } + // Clone all the fused instructions for the new fusion instruction. + std::map old_to_new; + std::list> new_fused_instructions; + // Create the list of fused parameters by mapping through the cloned, + // fused instructions. + std::vector new_fused_parameters; + for (HloInstruction* old_fused_parameter : fused_parameters_) { + new_fused_instructions.push_back(old_fused_parameter->Clone()); + HloInstruction* new_fusion_parameter = new_fused_instructions.back().get(); + new_fusion_parameter->parent_fusion_instruction_ = new_instruction.get(); + new_fused_parameters.push_back(new_fusion_parameter); + InsertOrDie(&old_to_new, old_fused_parameter, new_fusion_parameter); + } + for (auto old_fused_instruction_iter = fused_instructions_.rbegin(); + old_fused_instruction_iter != fused_instructions_.rend(); + ++old_fused_instruction_iter) { + HloInstruction* old_fused_instruction = old_fused_instruction_iter->get(); + if (old_fused_instruction->opcode() == HloOpcode::kParameter) { + FindOrDie(old_to_new, old_fused_instruction); + continue; + } + std::vector new_operands; + for (int64 operand_idx = 0; + operand_idx < old_fused_instruction->operand_count(); ++operand_idx) { + HloInstruction* old_operand = + old_fused_instruction->mutable_operand(operand_idx); + new_operands.push_back(FindOrDie(old_to_new, old_operand)); + } + new_fused_instructions.push_back( + old_fused_instruction->CloneWithNewOperands( + old_fused_instruction->shape(), new_operands)); + HloInstruction* new_fused_instruction = new_fused_instructions.back().get(); + new_fused_instruction->parent_fusion_instruction_ = new_instruction.get(); + InsertOrDie(&old_to_new, old_fused_instruction, new_fused_instruction); + } + // We iterated the fusion instructions in reverse post order which means + // that we must reverse our new list of fusion instructions. 
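+  // (Cloning in this order guarantees that an instruction's operands are
+  // already present in old_to_new when the instruction itself is cloned; the
+  // std::reverse below then restores reverse post order, i.e. users before
+  // their operands, as in fused_instructions_.)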
+ std::reverse(new_fused_instructions.begin(), new_fused_instructions.end()); + new_instruction->fusion_kind_ = fusion_kind_; + new_instruction->fused_instructions_ = std::move(new_fused_instructions); + new_instruction->fused_parameters_ = std::move(new_fused_parameters); + new_instruction->fused_root_ = FindOrDie(old_to_new, fused_root_); + new_instruction->CheckFusionInstruction(); + return new_instruction; +} + +const Literal& HloInstruction::literal() const { + CHECK_EQ(HloOpcode::kConstant, opcode_); + return *literal_; +} + +bool HloInstruction::CanHaveDimensionsField() const { + return (opcode() == HloOpcode::kReverse || + opcode() == HloOpcode::kConcatenate || + opcode() == HloOpcode::kReduce || opcode() == HloOpcode::kBroadcast || + opcode() == HloOpcode::kTranspose); +} + +const std::vector& HloInstruction::dimensions() const { + CHECK(CanHaveDimensionsField()); + return dimensions_; +} + +int64 HloInstruction::dimensions(int64 index) const { + return dimensions()[index]; +} + +int64 HloInstruction::concatenate_dimension() const { + CHECK(opcode() == HloOpcode::kConcatenate); + CHECK_EQ(1, dimensions_.size()); + return dimensions(0); +} + +int64 HloInstruction::tuple_index() const { + CHECK_EQ(HloOpcode::kGetTupleElement, opcode_); + return tuple_index_; +} + +const HloInstruction* HloInstruction::operand(int64 i) const { + return operands_[i]; +} + +HloInstruction* HloInstruction::mutable_operand(int64 i) { + CHECK(operands_[i] != nullptr); + return operands_[i]; +} + +int64 HloInstruction::operand_index(const HloInstruction* target) const { + for (int64 i = 0; i < operand_count(); ++i) { + if (target == operand(i)) { + return i; + } + } + LOG(FATAL) << "target was not an operand"; +} + +void HloInstruction::AppendOperand(HloInstruction* operand) { + operands_.push_back(operand); + operand->AddUser(this); +} + +void HloInstruction::AddUser(HloInstruction* user) { users_.insert(user); } + +bool HloInstruction::IsConstant() const { + return opcode_ == HloOpcode::kConstant; +} + +bool HloInstruction::HasConstantOperand() const { + for (const HloInstruction* operand : operands_) { + if (operand->IsConstant()) { + return true; + } + } + return false; +} + +void HloInstruction::AddControlPredecessor(HloInstruction* instruction) { + control_predecessors_.insert(instruction); +} + +bool HloInstruction::Identical( + const HloInstruction& other, + std::function + eq_operands, + std::function + eq_computations) const { + // An instruction is always identical to itself. + if (this == &other) { + return true; + } + + // Identical instruction must have the same opcode and identical operands. In + // general, there is no need to check shape because shape is inferred from the + // shape of the operands. + if (opcode() != other.opcode() || + !ContainersEqual(operands(), other.operands(), eq_operands)) { + return false; + } + + // Perform opcode specific checks. + switch (opcode()) { + // The result of these instructions only depend upon their opcode and + // operands. 
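+    // The opcode and operands were already compared above, so nothing further
+    // distinguishes two instructions with one of these opcodes; they are
+    // identical whenever their operands are.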
+ case HloOpcode::kAbs: + case HloOpcode::kAdd: + case HloOpcode::kCeil: + case HloOpcode::kClamp: + case HloOpcode::kCopy: + case HloOpcode::kCrossReplicaSum: + case HloOpcode::kDivide: + case HloOpcode::kDot: + case HloOpcode::kEq: + case HloOpcode::kExp: + case HloOpcode::kFloor: + case HloOpcode::kGe: + case HloOpcode::kGt: + case HloOpcode::kLe: + case HloOpcode::kLog: + case HloOpcode::kLogicalAnd: + case HloOpcode::kLogicalNot: + case HloOpcode::kLogicalOr: + case HloOpcode::kLt: + case HloOpcode::kMaximum: + case HloOpcode::kMinimum: + case HloOpcode::kMultiply: + case HloOpcode::kNe: + case HloOpcode::kNegate: + case HloOpcode::kPower: + case HloOpcode::kRemainder: + case HloOpcode::kSelect: + case HloOpcode::kSign: + case HloOpcode::kSubtract: + case HloOpcode::kTanh: + case HloOpcode::kTuple: + return true; + + // These opcodes have complex or special behavior so just return false. + case HloOpcode::kFusion: + case HloOpcode::kRng: + case HloOpcode::kTrace: + case HloOpcode::kWhile: + return false; + + case HloOpcode::kParameter: + return parameter_number() == other.parameter_number() && + // Check the shape too because `this` and `other` may be in + // different HloComputations. + ShapeUtil::Compatible(shape(), other.shape()); + + // A constant is defined by the value in the literal. + case HloOpcode::kConstant: + return LiteralUtil::Equal(literal(), other.literal()); + + // A convert result is determined by the primitive type that the operand is + // converted into. + case HloOpcode::kConvert: + return shape().element_type() == other.shape().element_type(); + + // Convolution has a window and dimensions. + case HloOpcode::kConvolution: + return protobuf_util::ProtobufEquals(window(), other.window()) && + protobuf_util::ProtobufEquals( + convolution_dimension_numbers(), + other.convolution_dimension_numbers()); + + // Reduction results are determined by the reduction dimension and the + // reduction computation. + case HloOpcode::kReduce: + return dimensions() == other.dimensions() && + eq_computations(to_apply(), other.to_apply()); + case HloOpcode::kReduceWindow: + return eq_computations(to_apply(), other.to_apply()) && + protobuf_util::ProtobufEquals(window(), other.window()); + + // SelectAndScatter is determined by both select and scatter + // computation as well as the window configuration. + case HloOpcode::kSelectAndScatter: + return eq_computations(select(), other.select()) && + eq_computations(scatter(), other.scatter()) && + protobuf_util::ProtobufEquals(window(), other.window()); + + case HloOpcode::kReshape: + return ShapeUtil::Compatible(shape(), other.shape()); + + // Transpose result is determined by the final shape and the permutation. + case HloOpcode::kTranspose: + return ShapeUtil::Compatible(shape(), other.shape()) && + dimensions() == other.dimensions(); + + // Remaining instructions with special values. 
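+    // Equality for the remaining opcodes also depends on opcode-specific
+    // state: shapes, dimensions, slice bounds, padding and window configs,
+    // tuple indices, or the computations and call targets they reference.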
+ case HloOpcode::kBitcast: + return ShapeUtil::Equal(shape(), other.shape()); + case HloOpcode::kBroadcast: + return ShapeUtil::Compatible(shape(), other.shape()) && + dimensions() == other.dimensions(); + case HloOpcode::kConcatenate: + return dimensions() == other.dimensions(); + case HloOpcode::kGetTupleElement: + return tuple_index() == other.tuple_index(); + case HloOpcode::kPad: + return protobuf_util::ProtobufEquals(padding_config(), + other.padding_config()); + case HloOpcode::kSlice: + return slice_starts_ == other.slice_starts_ && + slice_limits_ == other.slice_limits_; + case HloOpcode::kDynamicSlice: + return ShapeUtil::Compatible(shape(), other.shape()) && + dynamic_slice_sizes_ == other.dynamic_slice_sizes_; + case HloOpcode::kDynamicUpdateSlice: + return ShapeUtil::Compatible(shape(), other.shape()); + case HloOpcode::kCall: + case HloOpcode::kMap: + return eq_computations(to_apply(), other.to_apply()); + case HloOpcode::kCustomCall: + return custom_call_target_ == other.custom_call_target_; + case HloOpcode::kReverse: + return dimensions() == other.dimensions(); + + // These opcodes are not yet supported. + case HloOpcode::kIndex: + case HloOpcode::kInfeed: + case HloOpcode::kSort: + case HloOpcode::kUpdate: + case HloOpcode::kSend: + case HloOpcode::kRecv: + return false; + } +} + +bool HloInstruction::IsRank2Transpose() const { + return (opcode_ == HloOpcode::kTranspose) && + dimensions_ == std::vector({1, 0}) && + shape_.dimensions_size() == 2 && + std::equal(shape_.dimensions().begin(), shape_.dimensions().end(), + operands_[0]->shape_.dimensions().rbegin()); +} + +void HloInstruction::RemoveUser(HloInstruction* user) { + auto user_it = users_.find(user); + CHECK(user_it != users_.end()); + users_.erase(user_it); +} + +void HloInstruction::ReplaceUseWith(HloInstruction* user, + HloInstruction* new_producer) { + CHECK(ShapeUtil::Compatible(shape(), new_producer->shape())) + << "this shape: " << ShapeUtil::HumanString(shape()) + << ", replacement shape: " + << ShapeUtil::HumanString(new_producer->shape()); + auto user_it = std::find(users_.begin(), users_.end(), user); + CHECK(user_it != users_.end()) << "Instruction " << user + << " not a use of instruction " << this; + users_.erase(user_it); + + CHECK_GT(std::count(user->operands_.begin(), user->operands_.end(), this), 0); + std::replace(user->operands_.begin(), user->operands_.end(), this, + new_producer); + new_producer->AddUser(user); +} + +void HloInstruction::ReplaceOperandWith(int64 operand_num, + HloInstruction* new_operand) { + CHECK_GE(operand_num, 0); + CHECK_LT(operand_num, operand_count()); + HloInstruction* old_operand = mutable_operand(operand_num); + CHECK(ShapeUtil::Compatible(old_operand->shape(), new_operand->shape())) + << old_operand->shape().ShortDebugString() << " is not compatible with " + << new_operand->shape().ShortDebugString(); + operands_[operand_num] = new_operand; + + if (std::find(operands_.begin(), operands_.end(), old_operand) == + operands_.end()) { + old_operand->RemoveUser(this); + } + new_operand->AddUser(this); +} + +void HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer) { + // We can't use range-based loop because the iterator is invalidated by call + // to ReplaceUseWith. + for (auto user = users_.begin(); user != users_.end();) { + auto this_user = user; + user++; + // It's possible that new_producer is a user of this instruction as might + // be the case when replacing an instruction with a kCopy of itself. 
In + // this case, don't do the replacement to avoid creating a cycle in the + // graph. + if (*this_user != new_producer) { + ReplaceUseWith(*this_user, new_producer); + } + } +} + +void HloInstruction::DetachFromOperands() { + CHECK_EQ(0, user_count()); + // An intruction may be repeated as an operand. To avoid calling RemoveUser + // twice on the same operand, keep a set of already detached operands. + std::set detached_operands; + for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) { + HloInstruction* operand = operands_[operand_num]; + if (detached_operands.count(operand) == 0) { + operand->RemoveUser(this); + detached_operands.insert(operand); + } + operands_[operand_num] = nullptr; + } +} + +HloComputation* HloInstruction::to_apply() const { + switch (opcode_) { + case HloOpcode::kCall: + case HloOpcode::kMap: + case HloOpcode::kReduceWindow: + case HloOpcode::kReduce: + return to_apply_; + default: + LOG(FATAL) << "Invalid instruction for to_apply(): " << ToString(); + } +} + +void HloInstruction::set_to_apply(HloComputation* computation) { + switch (opcode_) { + case HloOpcode::kCall: + case HloOpcode::kMap: + case HloOpcode::kReduceWindow: + case HloOpcode::kReduce: + to_apply_ = computation; + break; + default: + LOG(FATAL) << "Invalid instruction for to_apply(): " << ToString(); + } +} + +const string& HloInstruction::custom_call_target() const { + CHECK_EQ(opcode_, HloOpcode::kCustomCall); + return custom_call_target_; +} + +HloComputation* HloInstruction::while_condition() const { + CHECK_EQ(HloOpcode::kWhile, opcode_); + return condition_; +} + +HloComputation* HloInstruction::while_body() const { + CHECK_EQ(HloOpcode::kWhile, opcode_); + return body_; +} + +void HloInstruction::set_while_condition(HloComputation* computation) { + CHECK_EQ(HloOpcode::kWhile, opcode_); + condition_ = computation; +} + +void HloInstruction::set_while_body(HloComputation* computation) { + CHECK_EQ(HloOpcode::kWhile, opcode_); + body_ = computation; +} + +HloComputation* HloInstruction::select() const { + CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_); + return select_; +} + +HloComputation* HloInstruction::scatter() const { + CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_); + return scatter_; +} + +void HloInstruction::set_select(HloComputation* computation) { + CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_); + select_ = computation; +} + +void HloInstruction::set_scatter(HloComputation* computation) { + CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_); + scatter_ = computation; +} + +string HloInstruction::SignatureString() const { + string operands = tensorflow::str_util::Join( + operands_, ", ", [](string* out, HloInstruction* operand) { + tensorflow::strings::StrAppend( + out, ShapeUtil::HumanString(operand->shape())); + }); + return tensorflow::strings::StrCat("(", operands, ") -> ", + ShapeUtil::HumanString(shape())); +} + +string HloInstruction::ToString() const { + string operands; + if (opcode() == HloOpcode::kConstant) { + // For constants, emit the actual value in place of an empty operand list. + if (ShapeUtil::ElementsIn(shape()) <= 10) { + // LiteralUtil::ToString emits multidimensional arrays over multiple + // lines. Compact this into one line by stripping out white space. + string tmp = LiteralUtil::ToString(literal()); + std::replace(tmp.begin(), tmp.end(), '\n', ' '); + std::vector v = tensorflow::str_util::Split(tmp, ' '); + bool first = true; + // Concatenate elements in "v" with spaces separating them, but ignoring + // empty entries. 
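+      // For instance, a multi-line rendering such as
+      //   { { 1, 2 },
+      //     { 3, 4 } }
+      // would be compacted to the single line "{ { 1, 2 }, { 3, 4 } }".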
+ for (const auto& s : v) { + if (s.empty()) continue; + tensorflow::strings::StrAppend(&operands, (first ? "" : " "), s); + first = false; + } + } else { + // Don't try emitting large constants. + operands = "{...}"; + } + } else { + operands = tensorflow::str_util::Join( + operands_, ", ", [](string* out, HloInstruction* operand) { + tensorflow::strings::StrAppend( + out, ShapeUtil::HumanStringWithLayout(operand->shape()), " ", + operand->name()); + }); + } + string extra; + if (LayoutUtil::HasLayout(shape())) { + if (ShapeUtil::IsTuple(shape())) { + // Tuple shapes are recursive, so the layout field of the top-level shape + // does not include all layout information. In this case, print out the + // entire shape with layout. + tensorflow::strings::StrAppend(&extra, ", layout=", + ShapeUtil::HumanStringWithLayout(shape())); + } else { + tensorflow::strings::StrAppend( + &extra, tensorflow::strings::Printf( + ", layout=%s", + LayoutUtil::HumanString(shape().layout()).c_str())); + } + } + if (CanHaveDimensionsField()) { + tensorflow::strings::StrAppend( + &extra, ", dimensions={", + tensorflow::str_util::Join(dimensions(), ", "), "}"); + } + if (window_ != nullptr) { + tensorflow::strings::StrAppend(&extra, ", window=", + window_util::ToString(*window_)); + } + if (padding_config_ != nullptr) { + tensorflow::strings::StrAppend(&extra, ", padding=", + padding_config_->ShortDebugString()); + } + if (convolution_dimension_numbers_ != nullptr) { + tensorflow::strings::StrAppend( + &extra, + tensorflow::strings::Printf( + ", " + "conv_dim_nums={batch_dim=%lld,feature_dim=%lld,spatial_dims=(%s)," + "kernel_input_feature_dims=%lld,kernel_output_feature_dim=%lld," + "kernel_spatial_dims=(%s)}", + convolution_dimension_numbers_->batch_dimension(), + convolution_dimension_numbers_->feature_dimension(), + tensorflow::str_util::Join( + convolution_dimension_numbers_->spatial_dimensions(), ",") + .c_str(), + convolution_dimension_numbers_->kernel_input_feature_dimension(), + convolution_dimension_numbers_->kernel_output_feature_dimension(), + tensorflow::str_util::Join( + convolution_dimension_numbers_->kernel_spatial_dimensions(), + ",") + .c_str())); + } + if (to_apply_ != nullptr) { + tensorflow::strings::StrAppend(&extra, ", computation=", to_apply_->name()); + } + if (opcode() == HloOpcode::kWhile) { + tensorflow::strings::StrAppend(&extra, ", condition=", + while_condition()->name()); + tensorflow::strings::StrAppend(&extra, ", body=", while_body()->name()); + } + if (opcode() == HloOpcode::kGetTupleElement) { + tensorflow::strings::StrAppend(&extra, ", index=", tuple_index()); + } + return tensorflow::strings::Printf( + "%s %s = %s(%s)%s", ShapeUtil::HumanString(shape()).c_str(), + name().c_str(), HloOpcodeString(opcode()).c_str(), operands.c_str(), + extra.c_str()); +} + +string HloInstruction::ToShortString() const { + return tensorflow::strings::Printf( + "%s = %s(%s)", name().c_str(), HloOpcodeString(opcode()).c_str(), + tensorflow::str_util::Join(operands_, ", ", [](string* out, + HloInstruction* operand) { + tensorflow::strings::StrAppend(out, operand->name()); + }).c_str()); +} + +HloInstruction* HloInstruction::tracing() const { return trace_instruction_; } + +void HloInstruction::set_tracing(HloInstruction* trace_instruction) { + trace_instruction_ = trace_instruction; +} + +const string& HloInstruction::tracing_tag() const { + CHECK_EQ(HloOpcode::kTrace, opcode()); + CHECK(literal_ != nullptr); + return literal_->u8s(); +} + +bool HloInstruction::IsFused() const { + return 
parent_fusion_instruction_ != nullptr; +} + +bool HloInstruction::IsFusable() const { + // Instructions which are traced should not be fused. + if (tracing()) { + return false; + } + + // Some kinds of instructions don't make sense to fuse. + switch (opcode_) { + case HloOpcode::kFusion: + case HloOpcode::kInfeed: + case HloOpcode::kParameter: + case HloOpcode::kTrace: + case HloOpcode::kSend: + case HloOpcode::kRecv: + return false; + default: + return true; + } +} + +HloInstruction* HloInstruction::fusion_instruction() const { + CHECK(IsFused()); + return parent_fusion_instruction_; +} + +HloInstruction* HloInstruction::fused_expression_root() const { + CHECK_EQ(opcode_, HloOpcode::kFusion); + return fused_root_; +} + +HloInstruction* HloInstruction::fused_parameter(int64 parameter_number) const { + CHECK_EQ(opcode_, HloOpcode::kFusion); + CHECK_GE(parameter_number, 0); + CHECK_LT(parameter_number, fused_parameters_.size()); + return fused_parameters_[parameter_number]; +} + +const std::list>& +HloInstruction::fused_instructions() const { + CHECK_EQ(opcode_, HloOpcode::kFusion); + return fused_instructions_; +} + +HloInstruction::HloInstruction(HloOpcode opcode, const Shape& shape) + : shape_(shape), opcode_(opcode), name_("%" + HloOpcodeString(opcode)) { + TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_)); +} + +Status HloInstruction::AcceptInternalVisit(DfsHloVisitor* visitor) { + switch (opcode_) { + case HloOpcode::kAbs: + return visitor->HandleAbs(this, operands_[0]); + case HloOpcode::kSign: + return visitor->HandleSign(this, operands_[0]); + case HloOpcode::kConstant: + return visitor->HandleConstant(this, *literal_); + case HloOpcode::kGetTupleElement: + return visitor->HandleGetTupleElement(this, operands_[0]); + case HloOpcode::kParameter: + return visitor->HandleParameter(this); + case HloOpcode::kEq: + case HloOpcode::kGe: + case HloOpcode::kGt: + case HloOpcode::kLe: + case HloOpcode::kLt: + case HloOpcode::kNe: + return visitor->HandleCompare(this, opcode_, operands_[0], operands_[1]); + case HloOpcode::kAdd: + return visitor->HandleAdd(this, operands_[0], operands_[1]); + case HloOpcode::kDivide: + return visitor->HandleDivide(this, operands_[0], operands_[1]); + case HloOpcode::kSubtract: + return visitor->HandleSubtract(this, operands_[0], operands_[1]); + case HloOpcode::kMaximum: + return visitor->HandleMaximum(this, operands_[0], operands_[1]); + case HloOpcode::kMinimum: + return visitor->HandleMinimum(this, operands_[0], operands_[1]); + case HloOpcode::kLogicalAnd: + return visitor->HandleLogicalAnd(this, operands_[0], operands_[1]); + case HloOpcode::kLogicalOr: + return visitor->HandleLogicalOr(this, operands_[0], operands_[1]); + case HloOpcode::kConcatenate: + return visitor->HandleConcatenate(this, operands_); + case HloOpcode::kConvert: + return visitor->HandleConvert(this, operands_[0]); + case HloOpcode::kCopy: + return visitor->HandleCopy(this, operands_[0]); + case HloOpcode::kMultiply: + return visitor->HandleMultiply(this, operands_[0], operands_[1]); + case HloOpcode::kDot: + return visitor->HandleDot(this, operands_[0], operands_[1]); + case HloOpcode::kPower: + return visitor->HandlePower(this, operands_[0], operands_[1]); + case HloOpcode::kRemainder: + return visitor->HandleRemainder(this, operands_[0], operands_[1]); + case HloOpcode::kSelect: + return visitor->HandleSelect(this, operands_[0], operands_[1], + operands_[2]); + case HloOpcode::kConvolution: + return visitor->HandleConvolution(this, operands_[0], operands_[1], + 
window()); + case HloOpcode::kCrossReplicaSum: + return visitor->HandleCrossReplicaSum(this); + case HloOpcode::kTuple: + return visitor->HandleTuple(this, operands_); + case HloOpcode::kMap: + return visitor->HandleMap(this, operands_, to_apply_, {}); + case HloOpcode::kClamp: + return visitor->HandleClamp(this, operands_[0], operands_[1], + operands_[2]); + case HloOpcode::kReduce: + return visitor->HandleReduce(this, operands_[0], operands_[1], + dimensions_, to_apply_); + case HloOpcode::kReduceWindow: + return visitor->HandleReduceWindow(this, operands_[0], window(), + to_apply_); + case HloOpcode::kSelectAndScatter: + return visitor->HandleSelectAndScatter(this); + case HloOpcode::kNegate: + return visitor->HandleNegate(this, operands_[0]); + case HloOpcode::kExp: + return visitor->HandleExp(this, operands_[0]); + case HloOpcode::kFloor: + return visitor->HandleFloor(this, operands_[0]); + case HloOpcode::kCeil: + return visitor->HandleCeil(this, operands_[0]); + case HloOpcode::kLog: + return visitor->HandleLog(this, operands_[0]); + case HloOpcode::kTanh: + return visitor->HandleTanh(this, operands_[0]); + case HloOpcode::kLogicalNot: + return visitor->HandleLogicalNot(this, operands_[0]); + case HloOpcode::kBitcast: + return visitor->HandleBitcast(this); + case HloOpcode::kBroadcast: + return visitor->HandleBroadcast(this); + case HloOpcode::kPad: + return visitor->HandlePad(this); + case HloOpcode::kReshape: + return visitor->HandleReshape(this); + case HloOpcode::kTranspose: + return visitor->HandleTranspose(this); + case HloOpcode::kReverse: + return visitor->HandleReverse(this, operands_[0]); + case HloOpcode::kSlice: + return visitor->HandleSlice(this, operands_[0]); + case HloOpcode::kDynamicSlice: + return visitor->HandleDynamicSlice(this, operands_); + case HloOpcode::kDynamicUpdateSlice: + return visitor->HandleDynamicUpdateSlice(this, operands_[0], operands_[1], + operands_[2]); + case HloOpcode::kSort: + return visitor->HandleSort(this, operands_[0]); + case HloOpcode::kInfeed: + return visitor->HandleInfeed(this); + case HloOpcode::kRng: + return visitor->HandleRng(this, distribution_); + case HloOpcode::kWhile: + return visitor->HandleWhile(this, operands_[0], condition_, body_); + case HloOpcode::kFusion: + return visitor->HandleFusion(this); + case HloOpcode::kCall: + return visitor->HandleCall(this, operands_, to_apply_); + case HloOpcode::kCustomCall: + return visitor->HandleCustomCall(this, operands_, custom_call_target_); + case HloOpcode::kSend: + return visitor->HandleSend(this); + case HloOpcode::kRecv: + return visitor->HandleRecv(this); + + // These opcodes are not handled here. + case HloOpcode::kIndex: + case HloOpcode::kTrace: + case HloOpcode::kUpdate: + break; + } + return Unimplemented("unhandled HloOpcode for DfsHloVisitor: %s", + HloOpcodeString(opcode_).c_str()); +} + +Status HloInstruction::AcceptInternal(DfsHloVisitor* visitor) { + // Do not visit this HLO node again if it is already visited. + if (visitor->DidVisit(*this)) { + VLOG(3) << "Not visiting HLO " << this << " as it was already visited."; + return Status::OK(); + } + + // If the instruction is in the visiting state, it means a cycle. 
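+  // The visitor tracks three states per instruction: not yet visited,
+  // visiting (currently on the DFS stack), and visited (fully processed).
+  // Encountering an instruction that is still in the visiting state means the
+  // traversal has followed an edge back into the current path, i.e. the graph
+  // contains a cycle.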
+ if (visitor->IsVisiting(*this)) { + return FailedPrecondition( + "A cycle is detected while visiting instruction %s", + ToString().c_str()); + } + visitor->SetVisiting(*this); + + for (auto operand : operands_) { + VLOG(3) << "Going to visit HLO " << operand << " as operand of HLO " + << this; + TF_RETURN_IF_ERROR(operand->AcceptInternal(visitor)); + } + + for (auto control_predecessor : control_predecessors_) { + VLOG(3) << "Going to visit HLO " << control_predecessor + << " as a control predecessor of HLO " << this; + TF_RETURN_IF_ERROR(control_predecessor->AcceptInternal(visitor)); + } + + TF_RETURN_IF_ERROR(visitor->Preprocess(this)); + VLOG(3) << "Visiting HLO " << this; + TF_RETURN_IF_ERROR(AcceptInternalVisit(visitor)); + visitor->SetVisited(*this); + return visitor->Postprocess(this); +} + +Status HloInstruction::Accept(DfsHloVisitor* visitor, bool call_finish_visit) { + auto status = AcceptInternal(visitor); + if (!status.ok()) { + return status; + } + + if (call_finish_visit) { + return visitor->FinishVisit(this); + } else { + return Status::OK(); + } +} + +namespace { + +// Returns true if the given order is a topological sort of exactly those +// instructions rooted at 'root'. +bool OrderIsTopologicalSort(HloInstruction* root, + const std::vector& order) { + // Create a map from instruction to its position in 'order'. + std::unordered_map order_position; + for (int i = 0; i < order.size(); i++) { + if (!order_position.insert(std::make_pair(order[i], i)).second) { + // Instruction order[i] is duplicated in the order. + return false; + } + } + // Verify that the operand of each instruction in the order is also in the + // order *and* the operand's position is earlier (defs are before uses for all + // ops). + for (auto* instruction : order) { + for (auto* operand : instruction->operands()) { + if (order_position.count(operand) == 0 || + order_position.at(operand) >= order_position.at(instruction)) { + return false; + } + } + } + + // Create a vector of all instructions in a DFS search starting at + // root. 'order' should contain exactly these instructions. + std::vector visited; + TF_CHECK_OK(root->Accept([&visited](HloInstruction* instruction) { + visited.push_back(instruction); + return Status::OK(); + })); + + if (order_position.size() != visited.size()) { + return false; + } + for (auto* instruction : visited) { + if (order_position.count(instruction) == 0) { + return false; + } + } + // Given the conditions above, the last element of order should always be the + // root. + CHECK_EQ(root, order[order.size() - 1]); + + return true; +} + +} // namespace + +Status HloInstruction::Accept(FunctionVisitor::VisitorFunction visitor_func) { + FunctionVisitor visitor(visitor_func); + return this->Accept(&visitor); +} + +Status HloInstruction::AcceptOrdered( + DfsHloVisitor* visitor, const std::vector& order) { + DCHECK(OrderIsTopologicalSort(this, order)); + for (auto* const_instruction : order) { + // The visitor can mark instructions as visited to skip particular + // instructions. 
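+    // (Unlike AcceptInternal, which discovers instructions by recursing
+    // through operands and control predecessors, this loop trusts the
+    // caller-supplied order -- DCHECKed above to be a topological sort -- and
+    // simply skips anything already marked visited.)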
+ if (visitor->DidVisit(*const_instruction)) { + VLOG(3) << "Not visiting HLO " << const_instruction + << " as it was already visited."; + continue; + } + + HloInstruction* instruction = + const_cast(const_instruction); + + TF_RETURN_IF_ERROR(visitor->Preprocess(instruction)); + VLOG(3) << "Visiting HLO " << instruction; + TF_RETURN_IF_ERROR(instruction->AcceptInternalVisit(visitor)); + visitor->SetVisited(*instruction); + TF_RETURN_IF_ERROR(visitor->Postprocess(instruction)); + } + + return visitor->FinishVisit(this); +} + +const Shape& HloInstruction::shape() const { + TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_)); + return shape_; +} + +std::vector HloInstruction::OperandIndices( + const HloInstruction* operand) const { + std::vector result; + for (int64 i = 0; i < operand_count(); ++i) { + if (this->operand(i) == operand) { + result.push_back(i); + } + } + return result; +} + +bool HloInstruction::IsElementwise() const { + switch (opcode_) { + // Nullary elementwise operations. + case HloOpcode::kConstant: + return true; + + // Unary elementwise operations. + case HloOpcode::kAbs: + case HloOpcode::kCeil: + case HloOpcode::kConvert: + case HloOpcode::kCopy: + case HloOpcode::kExp: + case HloOpcode::kFloor: + case HloOpcode::kLog: + case HloOpcode::kLogicalNot: + case HloOpcode::kNegate: + case HloOpcode::kSign: + case HloOpcode::kTanh: + return true; + + // Binary elementwise operations. + case HloOpcode::kAdd: + case HloOpcode::kDivide: + case HloOpcode::kEq: + case HloOpcode::kGe: + case HloOpcode::kGt: + case HloOpcode::kLe: + case HloOpcode::kLt: + case HloOpcode::kMaximum: + case HloOpcode::kMinimum: + case HloOpcode::kMultiply: + case HloOpcode::kNe: + case HloOpcode::kPower: + case HloOpcode::kRemainder: + case HloOpcode::kSubtract: + case HloOpcode::kLogicalAnd: + case HloOpcode::kLogicalOr: + return true; + + // Ternary elementwise operations. + case HloOpcode::kSelect: + return !ShapeUtil::IsTuple(shape_); + case HloOpcode::kClamp: + return true; + + // Other operations. + case HloOpcode::kMap: + return true; + case HloOpcode::kFusion: + if (fusion_kind() != FusionKind::kLoop) { + return false; + } + for (auto& fused : fused_instructions()) { + if (fused->opcode() != HloOpcode::kParameter && + !fused->IsElementwise()) { + return false; + } + } + return true; + + default: + return false; + } +} + +namespace { +bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction, + const HloInstruction* operand) { + std::vector operand_indices = instruction->OperandIndices(operand); + return std::all_of( + operand_indices.begin(), operand_indices.end(), + [instruction](int64 operand_index) { + return instruction->IsElementwiseOnOperand(operand_index); + }); +} +} // namespace + +bool HloInstruction::IsElementwiseOnOperand(int64 operand_idx) const { + // For all instructions other than kFusion, being elementwise on one of the + // operands is equivalent to being elementwise on all the operands. + if (opcode() != HloOpcode::kFusion) { + return IsElementwise(); + } + + CHECK_EQ(HloOpcode::kFusion, opcode()); + if (fusion_kind() != FusionKind::kLoop) { + return false; + } + + // A loop-fusion is elementwise on an operand if all operations (computed + // using BFS) between the operand and the fused root are elementwise. 
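+  // For example, in a loop fusion computing (p0 + p1) * broadcast(p2), every
+  // instruction reachable from p0 or p1 is elementwise, so the fusion is
+  // elementwise on operands 0 and 1; p2, however, feeds a broadcast, so the
+  // fusion is not elementwise on operand 2.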
+ std::deque worklist; + std::unordered_set visited; + worklist.push_back(fused_parameter(operand_idx)); + visited.insert(fused_parameter(operand_idx)); + while (!worklist.empty()) { + HloInstruction* operand = worklist.front(); + worklist.pop_front(); + for (HloInstruction* user : operand->users()) { + if (visited.count(user)) { + continue; + } + if (user->IsElementwise() || + IsInstructionElementwiseOnOperand(user, operand)) { + worklist.push_back(user); + visited.insert(user); + } else { + return false; + } + } + } + return true; +} + +HloInstruction::UseKind HloInstruction::OperandElementUse(int64 i) const { + switch (opcode_) { + case HloOpcode::kBitcast: + case HloOpcode::kConcatenate: + case HloOpcode::kReshape: + case HloOpcode::kReverse: + case HloOpcode::kSlice: + case HloOpcode::kTranspose: + return UseKind::kUsePermutingElements; + case HloOpcode::kPad: + case HloOpcode::kReduce: + // Pad reuses the padding value but not the padded array elements. + // Reduce reuses the init value but not the operand array elements. + return i > 0 ? UseKind::kReuse : UseKind::kUsePermutingElements; + case HloOpcode::kFusion: { + tensorflow::gtl::FlatMap cache; + // We could rather iterate backwards thru fused_instructions_ here, as it + // is in reverse postorder, and compute whether each fused instruction + // reuses the value of this parameter, which would save stack space but + // not allow us to finish early if we find a reuse. + std::function reuses_parameter_elements = + [i, &cache, &reuses_parameter_elements](const HloInstruction& hlo) { + auto plus = [](const UseKind& a, const UseKind& b) { + if (a == UseKind::kNoUse) return b; + if (b == UseKind::kNoUse) return a; + if (a == UseKind::kReuse || b == UseKind::kReuse) { + return UseKind::kReuse; + } + if (a == UseKind::kUsePermutingElements || + b == UseKind::kUsePermutingElements) { + return UseKind::kReuse; + } + CHECK(UseKind::kUse == a && UseKind::kUse == b); + return UseKind::kUse; + }; + + if (hlo.opcode_ == HloOpcode::kParameter && + hlo.parameter_number_ == i) { + return UseKind::kUse; + } + if (cache.count(&hlo) == 0) { + for (int64 j = 0; j < hlo.operands_.size(); ++j) { + UseKind old = cache[&hlo]; + UseKind updated = plus( + old, std::min(hlo.OperandElementUse(j), + reuses_parameter_elements(*hlo.operand(j)))); + cache[&hlo] = updated; + } + } + return cache[&hlo]; + }; + return reuses_parameter_elements(*fused_root_); + } + default: + return IsElementwise() ? UseKind::kUse : UseKind::kReuse; + } +} + +namespace { + +// Prereq: `order` is a permutation of {0, 1, ..., `dims.size()-1`} +void Strip1SizedDimensions(tensorflow::protobuf::RepeatedField* dims, + std::vector* order) { + // We can't merely call StripDegenerateDimensions here as we must also delete + // the dimension indices. + for (size_t i = 0; i < dims->size(); ++i) { + if (1 == dims->Get(i)) { + dims->erase(dims->begin() + i); + // We must find this, as order must be a permutation of operand + // dimensions. 
+ order->erase(std::find(order->begin(), order->end(), i)); + } + } +} + +} // namespace + +std::tuple, std::vector> +HloInstruction::ReshapeMerelyInsertsOrDeletes1SizedDimensions() const { + if (HloOpcode::kReshape != opcode_) { + return std::make_tuple(false, std::vector(), std::vector()); + } + return ShapeUtil::InsertedOrDeleted1SizedDimensions(operand(0)->shape_, + shape_); +} + +string FusionKindString(HloInstruction::FusionKind kind) { + switch (kind) { + case HloInstruction::FusionKind::kLoop: + return "Loop"; + case HloInstruction::FusionKind::kInput: + return "Input"; + case HloInstruction::FusionKind::kTransposeDot: + return "TransposeDot"; + case HloInstruction::FusionKind::kConvBackwardFilter: + return "ConvBackwardFilter"; + case HloInstruction::FusionKind::kConvBackwardInput: + return "ConvBackwardInput"; + } +} + +bool HloInstruction::CouldBeBitcast() const { + switch (opcode_) { + case HloOpcode::kTranspose: + return true; + case HloOpcode::kReshape: + return std::get<0>(ReshapeMerelyInsertsOrDeletes1SizedDimensions()); + default: + return false; + } +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h new file mode 100644 index 0000000000..8e7a253578 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -0,0 +1,791 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// HLO instructions are in DAG form and represent the computations that the user +// has built up via the XLA service interface. They are ultimately lowered +// in a platform-aware way by traversing the HLO DAG and emitting a lowered +// form; e.g. see DfsHloVisitor. + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +class HloComputation; + +// HLO instructions are the IR used by the high-level compiler. +class HloInstruction { + public: + enum class FusionKind { + kLoop, // Fused into a loop. + kInput, // Fused into a reduction kernel. + kTransposeDot, // Fused into a dot with transposed operands. + kConvBackwardFilter, // Fused into a backward filter convolution. + kConvBackwardInput, // Fused into a backward input convolution. 
+ }; + + // Creates a parameter-retrieving instruction. + static std::unique_ptr CreateParameter(int64 parameter_number, + const Shape& shape, + const string& name); + + // Creates a literal constant instruction. + static std::unique_ptr CreateConstant( + std::unique_ptr literal); + + // Creates a get tuple element instruction. + static std::unique_ptr CreateGetTupleElement( + const Shape& shape, HloInstruction* operand, int64 index); + + // Creates a trace instruction that logs the input operand in the computation. + static std::unique_ptr CreateTrace(const string& tag, + HloInstruction* operand); + + // Creates a random number generation instruction that fills a shape with + // random numbers from a given distribution. + static std::unique_ptr CreateRng( + const Shape& shape, RandomDistribution distribution, + tensorflow::gtl::ArraySlice parameters); + + // Creates an n-ary elementwise operation. + static std::unique_ptr CreateNary( + const Shape& shape, HloOpcode opcode, + tensorflow::gtl::ArraySlice operands); + + // Creates a unary instruction (one operand). + // Precondition: opcode must be a legitimate unary operation. + static std::unique_ptr CreateUnary(const Shape& shape, + HloOpcode opcode, + HloInstruction* operand); + + // Creates a binary instruction (two operands). + // Precondition: opcode must be a legitimate binary operation. + static std::unique_ptr CreateBinary(const Shape& shape, + HloOpcode opcode, + HloInstruction* lhs, + HloInstruction* rhs); + + // Creates a ternary instruction (three operands). + // Precondition: opcode must be a legitimate ternary operation. + static std::unique_ptr CreateTernary(const Shape& shape, + HloOpcode opcode, + HloInstruction* lhs, + HloInstruction* rhs, + HloInstruction* ehs); + + // Creates a variadic instruction (variable number of operands). + // Precondition: opcode must be a legitimate variadic operation. + static std::unique_ptr CreateVariadic( + const Shape& shape, HloOpcode opcode, + tensorflow::gtl::ArraySlice operands); + + // Creates a map instruction, where the computation (given by the handle) is + // applied element-wise to every element in operands (across the operands, + // at a given index) with the same `static_operands`. + static std::unique_ptr CreateMap( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + HloComputation* map_computation, + tensorflow::gtl::ArraySlice static_operands = {}); + + // Creates a convolution op, where rhs is the convolutional filter + // and window describes how the filter is applied to lhs. + static std::unique_ptr CreateConvolve( + const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, + const Window& window, + const ConvolutionDimensionNumbers& dimension_numbers); + + // Creates a cross replica sum op. + static std::unique_ptr CreateCrossReplicaSum( + const Shape& shape, HloInstruction* operand); + + // Creates a conversion instruction, where operand is the data to convert and + // shape is the target shape for the conversion. + static std::unique_ptr CreateConvert(const Shape& shape, + HloInstruction* operand); + + // Creates an infeed instruction, which reads data of the given shape from the + // Infeed interface of the device. + static std::unique_ptr CreateInfeed(const Shape& shape); + + // Creates a send instruction, which sends the operand data to a receive + // instruction in another computation. 
+ static std::unique_ptr CreateSend(HloInstruction* operand); + + // Creates a receive instruction, which receives data of the given shape + // from a send instruction in another computation. + static std::unique_ptr CreateRecv(const Shape& shape); + + // Creates a slice instruction, where the operand is sliced by the given + // start/limit indices. + static std::unique_ptr CreateSlice( + const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices); + + // Creates a slice instruction, where the first operand is sliced by + // start indices specified in the second operand, and by size specfied in + // 'slice_sizes'. + static std::unique_ptr CreateDynamicSlice( + const Shape& shape, HloInstruction* operand, + HloInstruction* start_indices, + tensorflow::gtl::ArraySlice slice_sizes); + + // Creates a dynamic update slice instruction, which updates a slice + // of 'operand' with 'update' and 'start_indices'. + static std::unique_ptr CreateDynamicUpdateSlice( + const Shape& shape, HloInstruction* operand, HloInstruction* update, + HloInstruction* start_indices); + + // Creates a concatenate instruction, where the operands are concatenated on + // the provided dimension. + static std::unique_ptr CreateConcatenate( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + int64 dimension); + + // Creates a reduce instruction, where the computation (given by the handle) + // is applied successively to every element in operand. That is, if f is the + // function to apply (which either takes 2 [accumulator, value] or 3 + // [accumulator, index, value] arguments) and init is a reduction operator + // specified initial value (for example, 0 for addition), then this operation + // will compute: + // f(f(init, [index0], value0), [index1], value1), ...) + static std::unique_ptr CreateReduce( + const Shape& shape, HloInstruction* operand, HloInstruction* init_value, + tensorflow::gtl::ArraySlice dimensions_to_reduce, + HloComputation* reduce_computation); + + // Creates a reduce-window instruction, where the computation (given + // by the handle) is applied window-wise at each valid window + // position in the operand. + static std::unique_ptr CreateReduceWindow( + const Shape& shape, HloInstruction* operand, HloInstruction* init_value, + const Window& window, HloComputation* reduce_computation); + + // Creates a scatter computation that scatters the `source` array to the + // selected indices of each window. + static std::unique_ptr CreateSelectAndScatter( + const Shape& shape, HloInstruction* operand, HloComputation* select, + const Window& window, HloInstruction* source, HloInstruction* init_value, + HloComputation* scatter); + + // Creates a broadcast instruction. + static std::unique_ptr CreateBroadcast( + const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice broadcast_dimensions); + + // Creates a pad instruction, where the operand is padded on the edges and + // between the elements with the given padding value. + static std::unique_ptr CreatePad( + const Shape& shape, HloInstruction* operand, + HloInstruction* padding_value, const PaddingConfig& padding_config); + + // Creates a reshape instruction, where the operand is flattened row-major + // order and then reshaped to the given result shape. + static std::unique_ptr CreateReshape(const Shape& shape, + HloInstruction* operand); + + // Creates a transpose instruction which permutes the operand dimensions. 
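+  //
+  // Illustrative sketch only: transposing a [5,10] F32 operand into a [10,5]
+  // result would look roughly like
+  //   HloInstruction::CreateTranspose(
+  //       ShapeUtil::MakeShape(F32, {10, 5}), operand, /*dimensions=*/{1, 0});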
+ static std::unique_ptr CreateTranspose( + const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice dimensions); + + // Creates a while instruction, given a condition computation, a body + // computation, and the initial value for the input of the computations. For + // example, shape: S32, condition: i -> i < 1000, body: i -> i * 2, init: 1 + // corresponds to the C code below. + // int32 i = 1; int32 result = while(i < 1000) { i = i * 2 } + static std::unique_ptr CreateWhile(const Shape& shape, + HloComputation* condition, + HloComputation* body, + HloInstruction* init); + + // Creates a fusion instruction. A fusion instruction contains one or more + // fused instructions forming an expression with a single root + // "fused_root". Additional instructions can be added to the fusion + // instruction with the method FuseInstruction. + static std::unique_ptr CreateFusion( + const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root); + + // Creates a fusion instruction that represents backward convolution. This is + // similar to CreateFusion, but with extra arguments indicating the window and + // dimemsion mapping of the backward convolution. + static std::unique_ptr CreateFusionForBackwardConvolution( + const Shape& shape, FusionKind fusion_kind, const Window& window, + const ConvolutionDimensionNumbers& conv_dnums, + HloInstruction* fused_root); + + // Creates a call instruction that applies the given computation on the given + // operands. "shape" is the resultant shape. + static std::unique_ptr CreateCall( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + HloComputation* computation); + + // Creates a custom call instruction that applies the given custom call target + // to the given operands. "shape" is the resultant shape. + static std::unique_ptr CreateCustomCall( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + tensorflow::StringPiece custom_call_target); + + // Creates a tuple instruction with the given elements. This is a convenience + // wrapper around CreateVariadic. + static std::unique_ptr CreateTuple( + tensorflow::gtl::ArraySlice elements); + + // Creates a reverse instruction, which reverses the order of the elements + // in the specified dimensions. + static std::unique_ptr CreateReverse( + const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice dimensions); + + // Returns the opcode for this instruction. + HloOpcode opcode() const { return opcode_; } + + // Returns the result shape of this instruction. + const Shape& shape() const; + + // Returns the (mutable) result shape of this instruction. + Shape* mutable_shape() { return &shape_; } + + // Returns the ith operand to this instruction. + const HloInstruction* operand(int64 i) const; + + // Returns the ith operand to this instruction. + HloInstruction* mutable_operand(int64 i); + + // Returns the number of operands to this instruction. + int64 operand_count() const { return operands_.size(); } + + // Returns the vector of operands of this instruction. + const std::vector& operands() const { return operands_; } + + // Returns the index of 'target' in the operands sequence. + // Precondition: target must be an operand (or a fatal error will occur). + int64 operand_index(const HloInstruction* target) const; + + // Returns the number of users of this instruction. + int64 user_count() const { return users_.size(); } + + // Returns the users of this instruction. 
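+  //
+  // Illustrative note (names hypothetical): because this is a set of
+  // pointers, a membership check such as
+  //   bool consumed = instr->users().count(possible_user) > 0;
+  // is a cheap way to ask whether `possible_user` consumes this instruction.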
+ const std::set& users() const { return users_; } + + // Returns the set of control predecessors of this instruction. Control + // predecessors are the instructions that must be scheduled before the + // current instruction. + const std::set& control_predecessors() const { + return control_predecessors_; + } + + // Adds the given instruction to the set of control predecessors. + void AddControlPredecessor(HloInstruction* instruction); + + // Returns true if "other" performs the same computation as this instruction. + // Layout of the instructions' output array is not considered. + bool Identical( + const HloInstruction& other, + std::function + eq_operands = std::equal_to(), + std::function + eq_computations = std::equal_to()) const; + + // Returns whether the instruction has a constant operand. + bool HasConstantOperand() const; + + // Returns whether this instruction does a rank-2 transposition. + bool IsRank2Transpose() const; + + // Replaces the use of this instruction in "user" with "new_producer". Note + // that there might be multiple uses of this instruction in "user"; all will + // be replaced. + void ReplaceUseWith(HloInstruction* user, HloInstruction* new_producer); + + // Replaces the specified operand with new_operand. + void ReplaceOperandWith(int64 operand_no, HloInstruction* new_operand); + + // Replaces all uses of this instruction with the new producer. If + // new_producer is a user of this instruction then new_producer remains a use + // of this instruction to avoid introducing cycles into the graph. + void ReplaceAllUsesWith(HloInstruction* new_producer); + + // Detaches an instruction from its operands. That is, remove the instruction + // from each operand's user set. This should only be called prior to + // deallocating the instruction. + void DetachFromOperands(); + + // Performs a postorder DFS visit using this node as the root. If + // call_finish_visit is true, then DfsHloVisitor::FinishVisit is called when + // complete. + Status Accept(DfsHloVisitor* visitor, bool call_finish_visit = true); + + // Performs a postorder DFS visit using this node as the root. Calls the given + // visitor function at each instruction. + Status Accept(FunctionVisitor::VisitorFunction visitor_func); + + // Visits all instructions rooted at this instruction using the given visitor + // in the given order. 'order' must contain exactly the set of instructions + // rooted at this node (ie, those accessible from a DFS traversal from this + // instruction). 'order' must also be a valid topological sort of these + // instructions (defs appear before uses). + Status AcceptOrdered(DfsHloVisitor* visitor, + const std::vector& order); + + // Returns the literal associated with this instruction. + // + // Note: only constant and parameter opcodes have an associated literal. + const Literal& literal() const; + + // Returns the parameter number associated with this instruction. + // + // Note: only parameter opcodes have an associated parameter number. + int64 parameter_number() const { + CHECK_EQ(HloOpcode::kParameter, opcode_); + return parameter_number_; + } + + const string& parameter_name() const { + CHECK_EQ(HloOpcode::kParameter, opcode_); + return parameter_name_; + } + + // Returns the dimension sizes or numbers associated with this instruction. + // + // Precondition: opcode() is one of: concatenate, reduce, broadcast, reshape, + // and reverse. 
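+  //
+  // For example (illustrative): a concatenate created along dimension 1 has
+  // dimensions() == {1}, and a reverse over dimensions {0, 2} has
+  // dimensions() == {0, 2}.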
+ const std::vector& dimensions() const; + int64 dimensions(int64 index) const; + + // Accessor for the dimension in which a concatenate HLO should occur. + // Precondition: opcode() == HloOpcode::kConcatenate + int64 concatenate_dimension() const; + + // Returns the tuple index associated with this instruction. + // + // Precondition: opcode() == HloOpcode::kGetTupleElement + int64 tuple_index() const; + + // Gets/sets the to_apply HloComputation for Call, Map, Reduce, etc. + // The setter should only be called by HloModule or HloComputation methods. + // + // Precondition: The instruction has a valid to_apply_ field. + HloComputation* to_apply() const; + void set_to_apply(HloComputation* to_apply); + + // Returns the custom_call_target for CustomCall. + // Precondition: opcode() == HloOpcode::kCustomCall + const string& custom_call_target() const; + + // Gets/sets the while_condition or while_body HloComputation for While. The + // setters should only be called by HloModule or HloComputation methods. + // + // Precondition: The instruction is a While instruction. + HloComputation* while_condition() const; + HloComputation* while_body() const; + void set_while_condition(HloComputation* while_condition); + void set_while_body(HloComputation* while_body); + + // Gets/sets the select or scatter HloComputation for SelectAndScatter. The + // setters should only be called by HloModule or HloComputation methods. + // + // Precondition: opcode() == HloOpcode::kSelectAndScatter. + HloComputation* select() const; + HloComputation* scatter() const; + void set_select(HloComputation* select); + void set_scatter(HloComputation* scatter); + + // Returns a string for the signature of this instruction if considered as a + // function, e.g. the signature of an F32 add is (F32, F32) -> F32. + string SignatureString() const; + + // Returns a debugging string that represents this instruction. + string ToString() const; + + // As ToString, but returns a shorter string. + string ToShortString() const; + + // Returns a logging instruction, if the output of this instruction is logged. + // + // Postcondition: retval == nullptr || retval->opcode() == HloOpcode::kTrace + HloInstruction* tracing() const; + void set_tracing(HloInstruction* trace_instruction); + + // Returns the channel id associated with the instruction. The id is + // shared between each Send/Recv pair and is globally unique to identify each + // channel. + // + // Precondition: opcode() == HloOpcode::kSend or HloOpcode::kRecv + int64 channel_id() const { return channel_id_; } + void set_channel_id(int64 id) { channel_id_ = id; } + + // Returns a tag to be used in tracing. + // + // Precondition: opcode() == HloOpcode::kTrace + const string& tracing_tag() const; + + // Returns whether the instruction is a constant. + bool IsConstant() const; + + // Returns true if this instruction is fused, ie contained within a fusion + // instruction. + bool IsFused() const; + + // Returns true if this instruction can be legally fused into a fusion + // instruction. + bool IsFusable() const; + + // Returns the fusion instruction that contains this instruction. + // + // Note: only valid if this instruction is fused into a fusion instruction. + HloInstruction* fusion_instruction() const; + + // Returns the root instruction of the fused expression contained within this + // fusion instruction. 
+ // + // Precondition: opcode() == HloOpcode::kFusion + HloInstruction* fused_expression_root() const; + + // Returns the vector of fused instructions inside this fusion + // instruction. The order is a reverse postorder of the fused expression (root + // is first in the order). + // + // Precondition: opcode() == HloOpcode::kFusion + const std::list>& fused_instructions() const; + + // Returns the fused parameter instruction in this fusion instruction + // corresponding to the given parameter number. + // + // Precondition: opcode() == HloOpcode::kFusion + HloInstruction* fused_parameter(int64 parameter_number) const; + + FusionKind fusion_kind() const { + CHECK_EQ(HloOpcode::kFusion, opcode_); + return fusion_kind_; + } + + // Fuses the given instruction in this fusion instruction. instruction_to_fuse + // is cloned and the clone is placed in the fusion + // instruction. instruction_to_fuse is unchanged. Instruction is cloned rather + // than moved to cleanly handle the case where the instruction has a use + // outside the fusion instruction. Moving such an instruction into a fusion + // instruction would violate the single-result invariant of HLO instructions + // and significantly complicate code generation. + // + // Precondition: this->opcode() == HloOpcode::kFusion + HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse); + + // Returns the start index in the given dimension for a slice node. + // + // Precondition: opcode() == HloOpcode::kSlice + int64 slice_starts(int64 dimension) const { + CHECK_EQ(HloOpcode::kSlice, opcode_); + return slice_starts_[dimension]; + } + const std::vector& slice_starts() const { return slice_starts_; } + + // Returns the (exclusive) limit index in the given dimension for a slice + // node. + // + // Precondition: opcode() == HloOpcode::kSlice + int64 slice_limits(int64 dimension) const { + CHECK_EQ(HloOpcode::kSlice, opcode_); + return slice_limits_[dimension]; + } + const std::vector& slice_limits() const { + CHECK_EQ(HloOpcode::kSlice, opcode_); + return slice_limits_; + } + + // Returns the size of the slice in the given dimension for a dynamic + // slice node. + // + // Precondition: opcode() == HloOpcode::kDynamicSlice + int64 slice_sizes(int64 dimension) const { + CHECK_EQ(HloOpcode::kDynamicSlice, opcode_); + return dynamic_slice_sizes_[dimension]; + } + const std::vector& dynamic_slice_sizes() const { + CHECK_EQ(HloOpcode::kDynamicSlice, opcode_); + return dynamic_slice_sizes_; + } + + // Returns data on the window in a windowed operation such as + // convolution. + const Window& window() const { + CHECK(window_ != nullptr); + return *window_; + } + + // Returns the padding configuration for a pad node. + // + // Precondition: opcode() == HloOpcode::kPad + const PaddingConfig& padding_config() const { + CHECK(padding_config_ != nullptr); + return *padding_config_; + } + + // Returns data on the dimension numbers used for a convolution + // operation. + const ConvolutionDimensionNumbers& convolution_dimension_numbers() const { + CHECK(convolution_dimension_numbers_ != nullptr); + return *convolution_dimension_numbers_; + } + + // Returns the random distribution for this rng node. + // + // Precondition: opcode() == HloOpcode::kRng + RandomDistribution random_distribution() const; + + // Clones the HLO instruction. The clone will have the same opcode, shape, and + // operands. After creation the clone has no uses. "this" (the instruction + // cloned from) is not changed. 
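+  //
+  // Illustrative sketch (names hypothetical): because the clone starts
+  // detached,
+  //   auto copy = instr->Clone();
+  //   // copy->user_count() == 0 and copy->opcode() == instr->opcode()
+  // until the copy is wired into a computation.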
+ std::unique_ptr Clone(); + + // Clones the HLO instruction as above but with new shape and operands. + std::unique_ptr CloneWithNewOperands( + const Shape& shape, + tensorflow::gtl::ArraySlice operands); + + // Computes and returns the computations this instruction calls (if any). This + // includes computations called by fused instructions inside of a fusion + // instruction. + std::set MakeCalledComputationsSet() const; + + // Returns true if this instruction performs an elementwise operation on + // `operand_idx`-th operand. An instruction is elementwise on an operand iff, + // after performing necessary implicit broadcast + // (cs/IrArray::EmitArrayElementAddress), to compute the output at index + // {i_0,i_1,...,i_n}, the only element required from the operand (if any) is + // the element at {i_0,i_1,...,i_n}. + // + // Note on performance: when this instruction is kFusion, this method, in the + // worst case, scans all fused instructions. We could speed this up by + // caching. + bool IsElementwiseOnOperand(int64 operand_idx) const; + + // Returns true if this instruction is elementwise on all its operands. + bool IsElementwise() const; + + // Returns whether this instruction may reuse elements of its `i`th operand. + bool ReusesOperandElements(int64 i) const { + return OperandElementUse(i) == UseKind::kReuse; + } + + // Returns the indices that the given operand appear in the operand list of + // this instruction. Note that an instruction can use the same operand + // multiple times. + std::vector OperandIndices(const HloInstruction* operand) const; + + // Convenience helper for ShapeUtil::InsertedOrDeleted1SizedDimensions. If + // this reshape merely inserts or deletes 1-sized dimensions, return the input + // indices of the deleted dimensions and the output indices of the inserted + // dimensions. + // + // Precondition: this op must be a reshape. + std::tuple, std::vector> + ReshapeMerelyInsertsOrDeletes1SizedDimensions() const; + + // Returns a string identifier for this instruction. If no string identifier + // has been explicitly set, then the identifier is the serialized pointer to + // this instruction. + const string& name() const { return name_; } + + // Sets the string identifier for this instruction. + void set_name(const string& name) { name_ = name; } + + // Set/get the computation containing this instruction. set_parent should only + // be called by HloComputation methods which add/remove instructions to + // computations. + void set_parent(HloComputation* computation) { parent_ = computation; } + const HloComputation* parent() const { return parent_; } + HloComputation* parent() { return parent_; } + + // Returns whether we could assign input and output layouts to this + // instruction to make it a bitcast. + bool CouldBeBitcast() const; + + private: + enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse }; + + // Appends operand to the list of operands and adds this instruction as a user + // of the operand. + void AppendOperand(HloInstruction* operand); + + // Adds a user for this instruction. + void AddUser(HloInstruction* user); + + // Removes a user for this instruction. + void RemoveUser(HloInstruction* user); + + // Internal constructor for a given opcode/shape, other fields must be filled + // by factory methods. + HloInstruction(HloOpcode opcode, const Shape& shape); + + // Clones the given instruction_to_fuse and insert the clone into this fusion + // instruction. 
+ // + // Precondition: opcode() == HloOpcode::kFusion + HloInstruction* CloneAndFuseInternal(HloInstruction* instruction_to_fuse); + + // Clones a fusion instruction with a new shape and operands. + std::unique_ptr CloneFusionWithNewOperands( + const Shape& shape, + tensorflow::gtl::ArraySlice operands); + + // Inner DFS traversal function -- this function being called (rather than + // Accept above) allows us to distinguish the root of the traversal. + Status AcceptInternal(DfsHloVisitor* visitor); + + // Inner DFS traversal function called when visiting this HloInstruction. + Status AcceptInternalVisit(DfsHloVisitor* visitor); + + // CHECKs various invariants of a fusion instruction. + void CheckFusionInstruction() const; + + // Returns true if this instruction can legally have the dimensions field + // set. Used for checking precondition of dimensions field accessors. + bool CanHaveDimensionsField() const; + + // Returns how this instruction uses elements of its `i`th operand. + UseKind OperandElementUse(int64 i) const; + + // Result shape of this instruction. + Shape shape_; + + // Opcode for this instruction. + HloOpcode opcode_; + + // Literal, only present for kConstant. + std::unique_ptr literal_; + + // Constant index, only present for kGetTupleElement. + int64 tuple_index_ = 0; + + // Dimensions present for some operations that require reshaping or + // broadcasting, including Reshape, Reduce, ReduceWindow, and Reverse. + std::vector dimensions_; + + // Describes the window in a windowed operation such as convolution. + std::unique_ptr window_; + + // Describes the dimension numbers used for a convolution. + std::unique_ptr convolution_dimension_numbers_; + + // Describes the [begin, end) index range for a slice. + std::vector slice_starts_; + std::vector slice_limits_; + + // Describes the [start, start + size) range size for a dynamic slice + // ('start' is specified dynamically in the second operand of the operation). + std::vector dynamic_slice_sizes_; + + // The padding configuration that describes the edge padding and interior + // padding of this pad instruction. Only set for pad instructions. + std::unique_ptr padding_config_; + + // The set of instruction fused into this fusion instruction. Only set for + // fusion instructions. + std::list> fused_instructions_; + + // If this instruction is fused into a fusion instruction, this field points + // to the fusion instruction. + HloInstruction* parent_fusion_instruction_ = nullptr; + + // The vector of parameter instructions inside this fusion instruction. The + // index of the vector is the parameter_number of the parameter instruction. + // This vector is non-empty only for fusion instructions. + std::vector fused_parameters_; + + // The root of the expression fused into this fusion instruction. + HloInstruction* fused_root_ = nullptr; + + // The type of the fusion. Used by kFusion only. + FusionKind fusion_kind_; + + // For parameter instructions this field holds the parameter number. + int64 parameter_number_ = 0; + string parameter_name_; + + // Computation to apply, only present for kCall, kMap, kReduce and + // kReduceWindow. + HloComputation* to_apply_ = nullptr; + + // Name of a global symbol to call, only present for kCustomCall. + string custom_call_target_; + + // Computation for condition and body of kWhile, only present for kWhile. + HloComputation* condition_ = nullptr; + HloComputation* body_ = nullptr; + + // Computation for select and scatter, only present for + // kSelectAndScatter. 
+  HloComputation* select_ = nullptr;
+  HloComputation* scatter_ = nullptr;
+
+  // Instruction operands.
+  std::vector<HloInstruction*> operands_;
+
+  // The users of this instruction. Users are HLOs where this instruction is an
+  // operand.
+  std::set<HloInstruction*> users_;
+
+  // The set of control predecessors of this instruction.
+  std::set<HloInstruction*> control_predecessors_;
+
+  // A trace instruction that consumes this instruction.
+  //
+  // Invariant: if trace_instruction_ != nullptr, trace_instruction has this as
+  // an operand.
+  HloInstruction* trace_instruction_ = nullptr;
+
+  // The distribution requested for random number generation.
+  // Only present for kRng.
+  RandomDistribution distribution_;
+
+  // Represents a unique identifier for each Send/Recv instruction pair.
+  // Only present for kSend or kRecv.
+  int64 channel_id_ = -1;
+
+  // String identifier for instruction.
+  string name_;
+
+  // The computation in which this instruction is contained.
+  HloComputation* parent_ = nullptr;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(HloInstruction);
+};
+
+string FusionKindString(HloInstruction::FusionKind kind);
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_
diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
new file mode 100644
index 0000000000..41164a2d58
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
@@ -0,0 +1,894 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+
+#include <memory>
+#include <set>
+#include <string>
+#include <unordered_map>
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/util.h"
+
+namespace xla {
+namespace {
+
+#define EXPECT_ISET(A, E...) EXPECT_EQ(A, (std::set<HloInstruction*>{E}))
+#define EXPECT_IVEC(A, E...) EXPECT_EQ(A, (std::vector<HloInstruction*>{E}))
+
+class HloInstructionTest : public ::testing::Test {
+ protected:
+  HloInstructionTest() {}
+
+  Shape r0f32_ = ShapeUtil::MakeShape(F32, {});
+};
+
+// Simple visitor that collects the number of users and operands for certain
+// HLO nodes. It also verifies some of the DFS visiting invariants (operands
+// visited before their users, nodes not visited twice, etc.)
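+//
+// (A minimal DfsHloVisitorWithDefault subclass -- illustrative sketch only --
+// needs just DefaultAction:
+//
+//   class CountingVisitor : public DfsHloVisitorWithDefault {
+//    public:
+//     Status DefaultAction(HloInstruction* hlo) override {
+//       ++count_;
+//       return Status::OK();
+//     }
+//     int64 count_ = 0;
+//   };
+//
+// The visitor below additionally overrides per-opcode handlers such as
+// HandleAdd and HandleNegate to check the visiting order.)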
+class OpAndUserCollectingVisitor : public DfsHloVisitorWithDefault { + public: + Status DefaultAction(HloInstruction* hlo_instruction) override { + return Unimplemented("not implemented %s", + HloOpcodeString(hlo_instruction->opcode()).c_str()); + } + + Status HandleParameter(HloInstruction* parameter) override { + EXPECT_EQ(0, count_.count(parameter)); + count_[parameter] = GetCountsForNode(parameter); + return Status::OK(); + } + + Status HandleConstant(HloInstruction* constant, + const Literal& literal) override { + EXPECT_EQ(0, count_.count(constant)); + count_[constant] = GetCountsForNode(constant); + return Status::OK(); + } + + Status HandleAdd(HloInstruction* add, HloInstruction* lhs, + HloInstruction* rhs) override { + EXPECT_EQ(0, count_.count(add)); + EXPECT_GT(count_.count(lhs), 0); + EXPECT_GT(count_.count(rhs), 0); + count_[add] = GetCountsForNode(add); + return Status::OK(); + } + + Status HandleNegate(HloInstruction* negate, + HloInstruction* operand) override { + EXPECT_EQ(0, count_.count(negate)); + EXPECT_GT(count_.count(operand), 0); + count_[negate] = GetCountsForNode(negate); + return Status::OK(); + } + + Status HandleMap( + HloInstruction* map, + tensorflow::gtl::ArraySlice operands, + HloComputation* /*function*/, + tensorflow::gtl::ArraySlice /*static_operands*/) + override { + EXPECT_EQ(0, count_.count(map)); + for (HloInstruction* arg : operands) { + EXPECT_GT(count_.count(arg), 0); + } + count_[map] = GetCountsForNode(map); + return Status::OK(); + } + + Status HandleReduce(HloInstruction* reduce, HloInstruction* arg, + HloInstruction* init_value, + tensorflow::gtl::ArraySlice dimensions, + HloComputation* function) override { + EXPECT_EQ(0, count_.count(reduce)); + EXPECT_GT(count_.count(arg), 0); + EXPECT_GT(count_.count(init_value), 0); + count_[reduce] = GetCountsForNode(reduce); + return Status::OK(); + } + + int64 NumOperands(const HloInstruction* node) { + auto count_iterator = count_.find(node); + EXPECT_NE(count_.end(), count_iterator); + return count_iterator->second.operand_count; + } + + int64 NumUsers(const HloInstruction* node) { + auto count_iterator = count_.find(node); + EXPECT_NE(count_.end(), count_iterator); + return count_iterator->second.user_count; + } + + private: + struct NumOpsAndUsers { + int64 operand_count; + int64 user_count; + }; + + // Helper function to count operands and users for the given HLO. + NumOpsAndUsers GetCountsForNode(const HloInstruction* node) { + NumOpsAndUsers counts{node->operand_count(), node->user_count()}; + return counts; + } + + // Counters for HLOs. Maps HLO to a NumOpsAndUsers. 
+ std::unordered_map count_; +}; + +TEST_F(HloInstructionTest, BasicProperties) { + auto parameter = HloInstruction::CreateParameter(1, r0f32_, "foo"); + + EXPECT_EQ(HloOpcode::kParameter, parameter->opcode()); + EXPECT_TRUE(ShapeUtil::IsScalarF32(parameter->shape())); + EXPECT_EQ(0, parameter->operand_count()); +} + +TEST_F(HloInstructionTest, UserWithTwoOperands) { + // [Param foo]-----> |-----| + // | Add | + // [Param bar]-----> |-----| + auto foo = HloInstruction::CreateParameter(0, r0f32_, "foo"); + auto bar = HloInstruction::CreateParameter(1, r0f32_, "bar"); + auto add = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo.get(), + bar.get()); + EXPECT_MATCH(add->operands(), testing::UnorderedMatcher( + foo.get(), bar.get())); + EXPECT_ISET(foo->users(), add.get()); + EXPECT_ISET(bar->users(), add.get()); + + OpAndUserCollectingVisitor visitor; + ASSERT_IS_OK(add->Accept(&visitor)); + + EXPECT_EQ(2, visitor.NumOperands(add.get())); + EXPECT_EQ(0, visitor.NumUsers(add.get())); + EXPECT_EQ(1, visitor.NumUsers(foo.get())); + EXPECT_EQ(1, visitor.NumUsers(bar.get())); +} + +TEST_F(HloInstructionTest, MultipleUsers) { + // [Param foo] + // / | \ + // / | \ [Param bar] + // / | \ | + // | | | | + // V V V V + // ------- ------- ----------- + // | exp | | exp | | add | + // ------- ------- ----------- + auto foo = HloInstruction::CreateParameter(0, r0f32_, "foo"); + auto bar = HloInstruction::CreateParameter(1, r0f32_, "bar"); + auto exp1 = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo.get()); + auto exp2 = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo.get()); + auto add = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo.get(), + bar.get()); + + EXPECT_EQ(3, foo->user_count()); + EXPECT_EQ(1, bar->user_count()); + EXPECT_EQ(0, exp1->user_count()); + EXPECT_EQ(0, exp2->user_count()); + EXPECT_EQ(0, add->user_count()); + + OpAndUserCollectingVisitor visitor; + ASSERT_IS_OK(add->Accept(&visitor)); + + EXPECT_EQ(2, visitor.NumOperands(add.get())); + EXPECT_EQ(3, visitor.NumUsers(foo.get())); +} + +TEST_F(HloInstructionTest, RepeatedUser) { + // Here we have a user 'add' nodes that uses the same HLO in both operands. + // Make sure we don't count it as two distinct users. + // + // [Param foo] + // | | + // | | + // | | + // V V + // ------- + // | add | + // ------- + auto foo = HloInstruction::CreateParameter(0, r0f32_, "foo"); + auto add = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo.get(), + foo.get()); + EXPECT_EQ(1, foo->user_count()); + + // But 'add' still has two operands, even if both are the same HLO. 
+ EXPECT_EQ(2, add->operand_count()); +} + +TEST_F(HloInstructionTest, MultipleUsersAndOperands) { + // [param0] [param1] + // | | + // | [c0] | + // | | | + // V | V + // ------- | ------- + // | add | <---^---> | add | + // ------- ------- + // | | + // \ ------- / + // ---->| add |<---- + // ------- + auto param0 = HloInstruction::CreateParameter(0, r0f32_, "param0"); + auto param1 = HloInstruction::CreateParameter(1, r0f32_, "param1"); + auto c0 = HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f)); + auto addleft = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, + param0.get(), c0.get()); + auto addright = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, + c0.get(), param1.get()); + auto addtotal = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, + addleft.get(), addright.get()); + + OpAndUserCollectingVisitor visitor; + ASSERT_IS_OK(addtotal->Accept(&visitor)); + + EXPECT_EQ(2, visitor.NumUsers(c0.get())); + EXPECT_EQ(2, visitor.NumOperands(addleft.get())); + EXPECT_EQ(2, visitor.NumOperands(addright.get())); + EXPECT_EQ(2, visitor.NumOperands(addtotal.get())); +} + +TEST_F(HloInstructionTest, MultipleUsersAndOperandsWithUnaryOps) { + // [param0] [c0] [param1] + // | | | + // | V | + // | ------- | + // | | neg | | + // | ------- | + // V | V + // ------- | ------- + // | add | <---^---> | add | + // ------- ------- + // | | + // \ ------- / + // ---->| add |<---- + // ------- + // | + // V + // ------- + // | neg | + // ------- + auto param0 = HloInstruction::CreateParameter(0, r0f32_, "param0"); + auto param1 = HloInstruction::CreateParameter(1, r0f32_, "param1"); + auto c0 = HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f)); + auto neg1 = HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, c0.get()); + auto addleft = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, + param0.get(), neg1.get()); + auto addright = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, + neg1.get(), param1.get()); + auto addtotal = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, + addleft.get(), addright.get()); + auto neg2 = + HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, addtotal.get()); + + OpAndUserCollectingVisitor visitor; + ASSERT_IS_OK(neg2->Accept(&visitor)); + + EXPECT_EQ(1, visitor.NumUsers(c0.get())); + EXPECT_EQ(2, visitor.NumUsers(neg1.get())); + EXPECT_EQ(2, visitor.NumOperands(addleft.get())); + EXPECT_EQ(2, visitor.NumOperands(addright.get())); + EXPECT_EQ(2, visitor.NumOperands(addtotal.get())); + EXPECT_EQ(1, visitor.NumOperands(neg2.get())); + EXPECT_EQ(0, visitor.NumUsers(neg2.get())); +} + +TEST_F(HloInstructionTest, TrivialMap) { + // This tests creating a trivial x+1 map as the only operation. + // + // param0[100x10] ---> (map x+1) + // + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + Shape f32a100x10 = ShapeUtil::MakeShape(F32, {100, 10}); + + // Builds an x+1.0 computation to use in a Map. + auto builder = HloComputation::Builder("f32+1"); + auto param = + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32, "x")); + auto value = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param, value)); + auto add_f32 = builder.Build(); + + // Builds a parameter and feeds it to the map. 
+ auto param0 = HloInstruction::CreateParameter(1, f32a100x10, ""); + auto map = + HloInstruction::CreateMap(f32a100x10, {param0.get()}, add_f32.get()); + + OpAndUserCollectingVisitor visitor; + ASSERT_IS_OK(map->Accept(&visitor)); + + // Check counts. We aren't walking the mapper computation yet. + EXPECT_EQ(1, visitor.NumUsers(param0.get())); + EXPECT_EQ(0, visitor.NumUsers(map.get())); + EXPECT_EQ(1, visitor.NumOperands(map.get())); + + // TODO(dehnert): Add walking and counters for the wrapped computation. +} + +TEST_F(HloInstructionTest, TrivialReduce) { + // This tests creating a trivial x+y reduce as the only operation. + // + // param0[100x10] ---> (reduce x+y) + // + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + Shape f32v100 = ShapeUtil::MakeShape(F32, {100}); + Shape f32a100x10 = ShapeUtil::MakeShape(F32, {100, 10}); + + // Builds an x+y computation to use in a Reduce. + auto builder = HloComputation::Builder("f32+f32"); + auto paramx = + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32, "x")); + auto paramy = + builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32, "y")); + builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, paramx, paramy)); + auto add_f32 = builder.Build(); + + // Builds a parameter and an initial value and feeds them to the reduce. + auto param0 = HloInstruction::CreateParameter(0, f32a100x10, ""); + auto const0 = + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f)); + auto c0 = HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f)); + auto reduce = + HloInstruction::CreateReduce(f32v100, param0.get(), const0.get(), + /*dimensions_to_reduce=*/{1}, add_f32.get()); + + OpAndUserCollectingVisitor visitor; + ASSERT_IS_OK(reduce->Accept(&visitor)); + + // Check counts. We aren't walking the reducer computation. + EXPECT_EQ(1, visitor.NumUsers(param0.get())); + EXPECT_EQ(1, visitor.NumUsers(const0.get())); + EXPECT_EQ(0, visitor.NumUsers(reduce.get())); + EXPECT_EQ(2, visitor.NumOperands(reduce.get())); +} + +TEST_F(HloInstructionTest, ReplaceUseInBinaryOps) { + // Construct a graph of a few binary ops using two different + // parameters. Replace one of the parameters with the other parameter in one + // of the instructions. + auto foo = HloInstruction::CreateParameter(0, r0f32_, "foo"); + auto bar = HloInstruction::CreateParameter(1, r0f32_, "bar"); + auto add_foobar = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, + foo.get(), bar.get()); + auto add_foofoo = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, + foo.get(), foo.get()); + auto sum = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, + add_foobar.get(), add_foofoo.get()); + + EXPECT_EQ(2, foo->user_count()); + EXPECT_EQ(1, bar->user_count()); + + // Replace the use of foo in add_foofoo with bar. + foo->ReplaceUseWith(add_foofoo.get(), bar.get()); + + EXPECT_EQ(1, foo->user_count()); + EXPECT_EQ(2, bar->user_count()); + + EXPECT_ISET(foo->users(), add_foobar.get()); + EXPECT_IVEC(add_foobar->operands(), foo.get(), bar.get()); + + EXPECT_ISET(bar->users(), add_foobar.get(), add_foofoo.get()); + EXPECT_IVEC(add_foobar->operands(), foo.get(), bar.get()); + EXPECT_IVEC(add_foofoo->operands(), bar.get(), bar.get()); +} + +TEST_F(HloInstructionTest, ReplaceUseInVariadicOp) { + // Construct a tuple containing several parameters. Replace one parameter with + // another in the tuple. 
+ auto foo = HloInstruction::CreateParameter(0, r0f32_, "foo"); + auto bar = HloInstruction::CreateParameter(1, r0f32_, "bar"); + auto baz = HloInstruction::CreateParameter(2, r0f32_, "baz"); + + auto tuple = + HloInstruction::CreateTuple({foo.get(), bar.get(), baz.get(), foo.get()}); + auto add_foobar = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, + foo.get(), bar.get()); + + EXPECT_EQ(2, foo->user_count()); + EXPECT_ISET(foo->users(), tuple.get(), add_foobar.get()); + + // Replace the use of foo in tuple with bar. + foo->ReplaceUseWith(tuple.get(), bar.get()); + + EXPECT_ISET(foo->users(), add_foobar.get()); + + // Both uses of foo in tuple should have been replaced with bar. + EXPECT_IVEC(tuple->operands(), bar.get(), bar.get(), baz.get(), bar.get()); +} + +TEST_F(HloInstructionTest, ReplaceUseInUnaryOp) { + // Construct a couple unary instructions which use a parameter. Replace the + // use of a parameter in one of the unary ops with the other parameter. + auto foo = HloInstruction::CreateParameter(0, r0f32_, "foo"); + auto bar = HloInstruction::CreateParameter(1, r0f32_, "bar"); + + auto exp = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo.get()); + auto log = HloInstruction::CreateUnary(r0f32_, HloOpcode::kLog, foo.get()); + + EXPECT_EQ(2, foo->user_count()); + EXPECT_ISET(foo->users(), exp.get(), log.get()); + EXPECT_EQ(0, bar->user_count()); + + // Replace the use of foo in exp with bar. + foo->ReplaceUseWith(exp.get(), bar.get()); + + // The use of foo in log should not have been affected. + EXPECT_EQ(1, foo->user_count()); + EXPECT_ISET(foo->users(), log.get()); + EXPECT_IVEC(log->operands(), foo.get()); + + // Bar should now be used in exp. + EXPECT_EQ(1, bar->user_count()); + EXPECT_EQ(*bar->users().begin(), exp.get()); + EXPECT_EQ(1, exp->operands().size()); + EXPECT_EQ(*exp->operands().begin(), bar.get()); +} + +TEST_F(HloInstructionTest, ReplaceAllUsesWithInBinaryOps) { + // Construct a simple graph of a few binary ops using two different + // parameters. Replace all uses of one of the parameters with the other + // parameter. + auto foo = HloInstruction::CreateParameter(0, r0f32_, "foo"); + auto bar = HloInstruction::CreateParameter(1, r0f32_, "bar"); + auto add_foobar = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, + foo.get(), bar.get()); + auto add_foofoo = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, + foo.get(), foo.get()); + auto sum = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, + add_foobar.get(), add_foofoo.get()); + + EXPECT_EQ(2, foo->user_count()); + EXPECT_EQ(1, bar->user_count()); + + // Replace all uses of foo with bar. + foo->ReplaceAllUsesWith(bar.get()); + + EXPECT_EQ(0, foo->user_count()); + EXPECT_EQ(2, bar->user_count()); + + EXPECT_ISET(bar->users(), add_foobar.get(), add_foofoo.get()); +} + +TEST_F(HloInstructionTest, ReplaceAllUsesInMultipleOps) { + // Construct a graph containing several ops (a unary, binary, and variadic) + // which use two parameters. Replace all uses of one of the parameters with + // the other parameter. + auto foo = HloInstruction::CreateParameter(0, r0f32_, "foo"); + auto bar = HloInstruction::CreateParameter(1, r0f32_, "bar"); + + auto add_foobar = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, + foo.get(), bar.get()); + auto exp = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo.get()); + auto tuple = HloInstruction::CreateTuple({foo.get(), bar.get()}); + + EXPECT_EQ(3, foo->user_count()); + EXPECT_EQ(2, bar->user_count()); + + // Replace all uses of foo with bar. 
+ foo->ReplaceAllUsesWith(bar.get()); + + EXPECT_EQ(0, foo->user_count()); + EXPECT_EQ(3, bar->user_count()); + + EXPECT_ISET(bar->users(), add_foobar.get(), exp.get(), tuple.get()); +} + +// Simple visitor that collects and post-processes each node in the graph. +class NodeCollectorAndPostProcessor : public DfsHloVisitorWithDefault { + public: + NodeCollectorAndPostProcessor() {} + + Status Postprocess(HloInstruction* hlo) override { + post_processed_nodes_.push_back(hlo); + return Status::OK(); + } + + Status DefaultAction(HloInstruction* hlo_instruction) override { + visited_nodes_.push_back(hlo_instruction); + return Status::OK(); + } + + const std::vector& visited_nodes() { + return visited_nodes_; + } + + const std::vector& post_processed_nodes() { + return post_processed_nodes_; + } + + private: + std::vector visited_nodes_; + std::vector post_processed_nodes_; +}; + +// Returns true if "vec" contains distinct nodes. +bool Distinct(const std::vector& vec) { + std::set distinct_nodes(vec.begin(), vec.end()); + return distinct_nodes.size() == vec.size(); +} + +TEST_F(HloInstructionTest, PostProcessAllVisitedNodes) { + // Verifies all the nodes are visited and post-processed in the same order, + // and that each node is visited exactly once. + // + // /--> exp --\ + // foo add + // \--> log --/ + auto foo = HloInstruction::CreateParameter(0, r0f32_, "foo"); + auto exp = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo.get()); + auto log = HloInstruction::CreateUnary(r0f32_, HloOpcode::kLog, foo.get()); + auto add = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, exp.get(), + log.get()); + + NodeCollectorAndPostProcessor visitor; + ASSERT_IS_OK(add->Accept(&visitor)); + // Verifies all the nodes are visited and post-processed in the same order. + EXPECT_EQ(visitor.visited_nodes(), visitor.post_processed_nodes()); + // Verifies each node is visited exactly once. + EXPECT_TRUE(Distinct(visitor.visited_nodes())); +} + +TEST_F(HloInstructionTest, SingletonFusionOp) { + // Create a fusion instruction containing a single unary operation. + auto constant = + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f)); + auto exp = + HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant.get()); + + auto fusion = HloInstruction::CreateFusion( + r0f32_, HloInstruction::FusionKind::kLoop, exp.get()); + + EXPECT_IVEC(fusion->operands(), constant.get()); + EXPECT_ISET(constant->users(), fusion.get(), exp.get()); +} + +TEST_F(HloInstructionTest, BinaryFusionOp) { + // Create a fusion instruction containing a single binary operation. + auto constant1 = + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f)); + auto constant2 = + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.1f)); + auto add = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, + constant1.get(), constant2.get()); + + auto fusion = HloInstruction::CreateFusion( + r0f32_, HloInstruction::FusionKind::kLoop, add.get()); + + EXPECT_IVEC(fusion->operands(), constant1.get(), constant2.get()); + EXPECT_ISET(constant1->users(), fusion.get(), add.get()); + EXPECT_ISET(constant2->users(), fusion.get(), add.get()); +} + +TEST_F(HloInstructionTest, ChainFusionOp) { + // Create a chain of fused unary ops. 
+ auto constant = + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f)); + auto exp1 = + HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant.get()); + auto exp2 = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp1.get()); + auto exp3 = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp2.get()); + + auto fusion = HloInstruction::CreateFusion( + r0f32_, HloInstruction::FusionKind::kLoop, exp3.get()); + fusion->FuseInstruction(exp2.get()); + fusion->FuseInstruction(exp1.get()); + + EXPECT_IVEC(fusion->operands(), constant.get()); + EXPECT_ISET(constant->users(), fusion.get(), exp1.get()); +} + +TEST_F(HloInstructionTest, ComplexFusionOp) { + // Fuse all instructions in complicated expression: + // + // add = Add(C1, C2) + // clamp = Clamp(C2, add, add) + // exp = Exp(add) + // mul = Mul(exp, C3) + // sub = Sub(mul, clamp) + // tuple = Tuple({sub, sub, mul, C1}) + // + // Notable complexities are repeated operands in a same instruction, different + // shapes, use of value in different expressions. + auto c1 = HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f)); + auto c2 = HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.1f)); + auto c3 = HloInstruction::CreateConstant(LiteralUtil::CreateR0(9.0f)); + + auto add = + HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, c1.get(), c2.get()); + auto clamp = HloInstruction::CreateTernary(r0f32_, HloOpcode::kClamp, + c2.get(), add.get(), add.get()); + auto exp = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, add.get()); + auto mul = HloInstruction::CreateBinary(r0f32_, HloOpcode::kMultiply, + exp.get(), c3.get()); + auto sub = HloInstruction::CreateBinary(r0f32_, HloOpcode::kSubtract, + mul.get(), clamp.get()); + auto tuple = + HloInstruction::CreateTuple({sub.get(), sub.get(), mul.get(), c1.get()}); + + auto fusion = HloInstruction::CreateFusion( + r0f32_, HloInstruction::FusionKind::kLoop, tuple.get()); + fusion->FuseInstruction(sub.get()); + fusion->FuseInstruction(mul.get()); + fusion->FuseInstruction(exp.get()); + fusion->FuseInstruction(clamp.get()); + fusion->FuseInstruction(add.get()); + + // Operands in the fusion instruction's operands() vector should be in the + // order in which their users were added fused. + EXPECT_IVEC(fusion->operands(), c1.get(), c3.get(), c2.get()); + EXPECT_ISET(c1->users(), add.get(), tuple.get(), fusion.get()); +} + +// Convenience function for comparing two HloInstructions inside of +// std::unique_ptrs. +static bool Identical(std::unique_ptr instruction1, + std::unique_ptr instruction2) { + // Verify Identical is reflexive for both instructions. + EXPECT_TRUE(instruction1->Identical(*instruction1)); + EXPECT_TRUE(instruction2->Identical(*instruction2)); + + bool is_equal = instruction1->Identical(*instruction2); + // Verify Identical is symmetric. + EXPECT_EQ(is_equal, instruction2->Identical(*instruction1)); + return is_equal; +} + +TEST_F(HloInstructionTest, IdenticalInstructions) { + // Test HloInstruction::Identical with some subset of instructions types. + + // Create a set of random constant operands to use below. Make them matrices + // so dimensions are interesting. + auto operand1 = HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}})); + auto operand2 = HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{10.0, 20.0}, {30.0, 40.0}})); + auto vector_operand = HloInstruction::CreateConstant( + LiteralUtil::CreateR1({42.0, 123.0})); + Shape shape = operand1->shape(); + + // Convenient short names for the operands. 
+ HloInstruction* op1 = operand1.get(); + HloInstruction* op2 = operand2.get(); + + // Operations which only depend on their operands and opcode. + EXPECT_TRUE( + Identical(HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op1), + HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op1))); + EXPECT_FALSE( + Identical(HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op1), + HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op2))); + EXPECT_FALSE( + Identical(HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op1), + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, op1))); + + // Tuples. + EXPECT_TRUE(Identical(HloInstruction::CreateTuple({op1, op2}), + HloInstruction::CreateTuple({op1, op2}))); + EXPECT_FALSE(Identical(HloInstruction::CreateTuple({op1, op2}), + HloInstruction::CreateTuple({op2, op1}))); + + // Broadcasts. + EXPECT_TRUE(Identical(HloInstruction::CreateBroadcast(shape, op1, {0, 1}), + HloInstruction::CreateBroadcast(shape, op1, {0, 1}))); + EXPECT_FALSE(Identical(HloInstruction::CreateBroadcast(shape, op1, {0, 1}), + HloInstruction::CreateBroadcast(shape, op1, {1, 0}))); + Shape bcast_shape1 = ShapeUtil::MakeShape(F32, {2, 2, 42}); + Shape bcast_shape2 = ShapeUtil::MakeShape(F32, {2, 2, 123}); + EXPECT_FALSE( + Identical(HloInstruction::CreateBroadcast(bcast_shape1, op1, {0, 1}), + HloInstruction::CreateBroadcast(bcast_shape2, op1, {0, 1}))); + + // Binary operands. + EXPECT_TRUE(Identical( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, op1, op2), + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, op1, op2))); + EXPECT_FALSE(Identical( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, op1, op2), + HloInstruction::CreateBinary(shape, HloOpcode::kDivide, op2, op1))); + EXPECT_FALSE(Identical( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, op1, op2), + HloInstruction::CreateBinary(shape, HloOpcode::kDivide, op1, op2))); +} + +TEST_F(HloInstructionTest, FunctionVisitor) { + // Verify the function visitor HloInstruction::Accept visits all instructions + // from a root properly given the following graph: + // + // param + // / \ + // negate exp + // \ / + // add + const Shape f32 = ShapeUtil::MakeShape(F32, {}); + auto param = HloInstruction::CreateParameter(0, f32, "0"); + auto negate = + HloInstruction::CreateUnary(f32, HloOpcode::kNegate, param.get()); + auto exp = HloInstruction::CreateUnary(f32, HloOpcode::kExp, param.get()); + auto add = HloInstruction::CreateBinary(f32, HloOpcode::kAdd, negate.get(), + exp.get()); + + int visit_num = 0; + std::unordered_map visit_order; + EXPECT_IS_OK(add->Accept([&visit_num, &visit_order](HloInstruction* inst) { + EXPECT_EQ(0, visit_order.count(inst)); + visit_order[inst] = visit_num; + visit_num++; + return Status::OK(); + })); + + EXPECT_EQ(0, visit_order.at(param.get())); + // negate and exp can be visited in an arbitrary order. 
+ EXPECT_TRUE(visit_order.at(exp.get()) == 1 || visit_order.at(exp.get()) == 2); + EXPECT_TRUE(visit_order.at(negate.get()) == 1 || + visit_order.at(negate.get()) == 2); + EXPECT_NE(visit_order.at(exp.get()), visit_order.at(negate.get())); + EXPECT_EQ(3, visit_order.at(add.get())); +} + +TEST_F(HloInstructionTest, FullyElementwise) { + const Shape r1f32 = ShapeUtil::MakeShape(F32, {5}); + auto x = HloInstruction::CreateParameter(0, r1f32, "x"); + auto y = HloInstruction::CreateParameter(1, r1f32, "y"); + auto add = + HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, x.get(), y.get()); + EXPECT_TRUE(add->IsElementwise()); + for (int i = 0; i < add->operand_count(); ++i) { + EXPECT_TRUE(add->IsElementwiseOnOperand(i)); + } +} + +TEST_F(HloInstructionTest, PartiallyElementwise) { + const Shape r1f32 = ShapeUtil::MakeShape(F32, {5}); + const Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 5}); + + // Fused expression: + // + // p0 p1 p2 p3 + // \ / / | + // mul / | + // \ / | + // div broadcast + // \ / + // max + // + // The fusion instruction is not elementwise on p3 because the broadcast is + // not elementwise. + HloComputation::Builder builder("PartiallyElementwise"); + HloInstruction* p0 = + builder.AddInstruction(HloInstruction::CreateParameter(0, r2f32, "p0")); + HloInstruction* p1 = + builder.AddInstruction(HloInstruction::CreateParameter(1, r2f32, "p1")); + HloInstruction* p2 = + builder.AddInstruction(HloInstruction::CreateParameter(2, r2f32, "p2")); + HloInstruction* p3 = + builder.AddInstruction(HloInstruction::CreateParameter(3, r1f32, "p3")); + HloInstruction* mul = builder.AddInstruction( + HloInstruction::CreateBinary(r2f32, HloOpcode::kMultiply, p0, p1)); + HloInstruction* div = builder.AddInstruction( + HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, mul, p2)); + // Dimension 0 of shape [5] is mapped to dimension 1 of shape [3x5]. + HloInstruction* broadcast = + builder.AddInstruction(HloInstruction::CreateBroadcast(r2f32, p3, {1})); + HloInstruction* max = builder.AddInstruction( + HloInstruction::CreateBinary(r2f32, HloOpcode::kMaximum, div, broadcast)); + + auto computation = builder.Build(); + HloInstruction* fusion = computation->CreateFusionInstruction( + {max, broadcast, div, mul}, HloInstruction::FusionKind::kLoop); + EXPECT_FALSE(fusion->IsElementwise()); + for (int64 operand_idx = 0; operand_idx < fusion->operand_count(); + ++operand_idx) { + const HloInstruction* operand = fusion->operand(operand_idx); + if (operand == p3) { + EXPECT_FALSE(fusion->IsElementwiseOnOperand(operand_idx)); + } else { + EXPECT_TRUE(fusion->IsElementwiseOnOperand(operand_idx)); + } + } +} + +TEST_F(HloInstructionTest, PartiallyElementwiseWithReuse) { + // Fused expression: + // + // x y + // \ / \ + // min broadcast + // \ / + // sub + // + // The fusion instruction is elementwise on `x` because the only path from x + // to sub contains only elementwise operations. It is not elementwise on `y` + // because the path y->broadcast->sub is not all elementwise. 
+ const Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + const Shape r1f32 = ShapeUtil::MakeShape(F32, {5}); + + HloComputation::Builder builder("PartiallyElementwiseWithReuse"); + HloInstruction* x = + builder.AddInstruction(HloInstruction::CreateParameter(0, r1f32, "x")); + HloInstruction* y = + builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32, "y")); + HloInstruction* min = builder.AddInstruction( + HloInstruction::CreateBinary(r1f32, HloOpcode::kMinimum, x, y)); + HloInstruction* broadcast = + builder.AddInstruction(HloInstruction::CreateBroadcast(r1f32, y, {0})); + HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary( + r1f32, HloOpcode::kSubtract, min, broadcast)); + + auto computation = builder.Build(); + HloInstruction* fusion = computation->CreateFusionInstruction( + {sub, broadcast, min}, HloInstruction::FusionKind::kLoop); + EXPECT_FALSE(fusion->IsElementwise()); + for (int64 operand_idx = 0; operand_idx < fusion->operand_count(); + ++operand_idx) { + if (fusion->operand(operand_idx) == x) { + EXPECT_TRUE(fusion->IsElementwiseOnOperand(operand_idx)); + } else { + EXPECT_FALSE(fusion->IsElementwiseOnOperand(operand_idx)); + } + } +} + +TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) { + // Fused expression: + // + // x y + // | | + // | transpose + // \ / + // dot + // + // Tests that shapes aren't mangled by Clone(). + const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); + const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10}); + const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20}); + const Shape sout = ShapeUtil::MakeShape(F32, {5, 20}); + + HloComputation::Builder builder("TransposeDot"); + HloInstruction* x = + builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x")); + HloInstruction* y = + builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y")); + HloInstruction* reshape = + builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0})); + HloInstruction* dot = builder.AddInstruction( + HloInstruction::CreateBinary(sout, HloOpcode::kDot, x, reshape)); + + auto computation = builder.Build(); + HloInstruction* fusion = computation->CreateFusionInstruction( + {dot, reshape}, HloInstruction::FusionKind::kTransposeDot); + + auto fusion2 = fusion->Clone(); + const HloInstruction* root = fusion->fused_expression_root(); + const HloInstruction* root2 = fusion2->fused_expression_root(); + EXPECT_TRUE(ShapeUtil::Equal(root->shape(), root2->shape())); + EXPECT_TRUE( + ShapeUtil::Equal(root->operand(0)->shape(), root2->operand(0)->shape())); + EXPECT_TRUE( + ShapeUtil::Equal(root->operand(1)->shape(), root2->operand(1)->shape())); + EXPECT_TRUE(ShapeUtil::Equal(root->operand(1)->operand(0)->shape(), + root2->operand(1)->operand(0)->shape())); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc new file mode 100644 index 0000000000..35110f3946 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -0,0 +1,269 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+
+#include <set>
+#include <sstream>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+HloComputation* HloModule::AddEntryComputation(
+    std::unique_ptr<HloComputation> computation) {
+  CHECK_EQ(nullptr, entry_computation_);
+  entry_computation_ = computation.get();
+  computation->set_parent(this);
+  computations_.push_back(std::move(computation));
+  return computations_.back().get();
+}
+
+HloComputation* HloModule::AddEmbeddedComputation(
+    std::unique_ptr<HloComputation> computation) {
+  computation->set_parent(this);
+  computations_.push_back(std::move(computation));
+  return computations_.back().get();
+}
+
+void HloModule::ReplaceComputations(
+    const std::unordered_map<HloComputation*, HloComputation*>& replacements) {
+  // Replace all uses of non-canonical computations with their
+  // representatives.
+  std::vector<std::unique_ptr<HloComputation>> new_computations;
+  new_computations.reserve(computations_.size());
+
+  for (std::unique_ptr<HloComputation>& computation : computations_) {
+    for (auto& instruction : computation->instructions()) {
+      switch (instruction->opcode()) {
+        case HloOpcode::kCall:
+        case HloOpcode::kMap:
+        case HloOpcode::kReduce:
+        case HloOpcode::kReduceWindow: {
+          HloComputation* new_arg = tensorflow::gtl::FindWithDefault(
+              replacements, instruction->to_apply(), nullptr);
+          if (new_arg != nullptr) {
+            instruction->set_to_apply(new_arg);
+          }
+          break;
+        }
+        case HloOpcode::kWhile: {
+          HloComputation* new_condition = tensorflow::gtl::FindWithDefault(
+              replacements, instruction->while_condition(), nullptr);
+          if (new_condition != nullptr) {
+            instruction->set_while_condition(new_condition);
+          }
+          HloComputation* new_body = tensorflow::gtl::FindWithDefault(
+              replacements, instruction->while_body(), nullptr);
+          if (new_body != nullptr) {
+            instruction->set_while_body(new_body);
+          }
+          break;
+        }
+        case HloOpcode::kSelectAndScatter: {
+          HloComputation* new_select = tensorflow::gtl::FindWithDefault(
+              replacements, instruction->select(), nullptr);
+          if (new_select != nullptr) {
+            instruction->set_select(new_select);
+          }
+          HloComputation* new_scatter = tensorflow::gtl::FindWithDefault(
+              replacements, instruction->scatter(), nullptr);
+          if (new_scatter != nullptr) {
+            instruction->set_scatter(new_scatter);
+          }
+          break;
+        }
+        default:
+          break;
+      }
+    }
+
+    if (replacements.find(computation.get()) == replacements.end()) {
+      new_computations.push_back(std::move(computation));
+    }
+  }
+
+  // Replace entry_computation if necessary.
+  entry_computation_ = tensorflow::gtl::FindWithDefault(
+      replacements, entry_computation_, entry_computation_);
+
+  computations_ = std::move(new_computations);
+}
+
+string HloModule::ToString() const {
+  std::ostringstream s;
+  s << "HloModule " << name() << ":\n\n";
+  s << "ENTRY " << entry_computation()->ToString() << "\n\n";
+  for (const std::unique_ptr<HloComputation>& computation : computations_) {
+    if (computation.get() != entry_computation()) {
+      s << computation->ToString() << "\n\n";
+    }
+  }
+  return s.str();
+}
+
+namespace {
+// Returns whether `hlo` is used outside the given subcomputation.
+// `instructions_in_subcomputation` is the instruction set of the given +// subcomputation. +bool IsUsedOutsideSubcomputation( + const HloInstruction& hlo, + const std::unordered_set& instructions_in_subcomputation) { + for (HloInstruction* user : hlo.users()) { + if (!instructions_in_subcomputation.count(user)) { + return true; + } + } + return false; +} +} // anonymous namespace + +HloInstruction* HloModule::OutlineExpressionFromComputation( + tensorflow::gtl::ArraySlice instructions_to_outline, + const string& outlined_computation_name, HloComputation* computation) { + auto builder = HloComputation::Builder(outlined_computation_name); + + // A map from original instructions to their counterparts in the new outlined + // function. + std::unordered_map outlined_instructions; + // A set that contains all instructions to be outlined. + std::unordered_set instruction_set_to_outline( + instructions_to_outline.begin(), instructions_to_outline.end()); + std::vector arguments; + std::vector outputs; + int64 parameter_count = 0; + for (HloInstruction* instruction_to_outline : instructions_to_outline) { + // Clone the original instruction. + HloInstruction* outlined_instruction = + builder.AddInstruction(instruction_to_outline->Clone()); + + // Replace its operands to their counterparts in the new function. + for (int64 operand_num = 0; + operand_num < outlined_instruction->operand_count(); ++operand_num) { + HloInstruction* old_operand = + outlined_instruction->mutable_operand(operand_num); + + HloInstruction** operand_slot = &(outlined_instructions[old_operand]); + if (*operand_slot == nullptr) { + // Because instructions_to_outline is in topological order, if + // old_operand is not in outlined_instructions, old_operand must be an + // input of the outlined subcomputation and thus should be represented + // as a parameter in the new function. + arguments.push_back(old_operand); + *operand_slot = builder.AddInstruction(HloInstruction::CreateParameter( + parameter_count, old_operand->shape(), "")); + ++parameter_count; + } + outlined_instruction->ReplaceOperandWith(operand_num, *operand_slot); + } + + // Insert the new instruction into the outlined_instructions map. + InsertOrDie(&outlined_instructions, instruction_to_outline, + outlined_instruction); + + // Mark instruction_to_outline an output if it is used outside the + // subcomputation or is the output of the original computation (i.e. used + // externally). + if (instruction_to_outline->user_count() == 0 || + IsUsedOutsideSubcomputation(*instruction_to_outline, + instruction_set_to_outline)) { + outputs.push_back(instruction_to_outline); + } + } + + if (outputs.size() != 1) { + string error_message = + "The subcomputation to outline has multiple outputs:\n"; + for (HloInstruction* output : outputs) { + tensorflow::strings::StrAppend(&error_message, output->ToString(), "\n"); + } + LOG(FATAL) << error_message; + } + HloInstruction* output = outputs[0]; + + // Creates a call to the nested computation. 
+ HloComputation* nested_computation = AddEmbeddedComputation( + builder.Build(FindOrDie(outlined_instructions, output))); + HloInstruction* call = computation->AddInstruction(HloInstruction::CreateCall( + output->shape(), arguments, nested_computation)); + + VLOG(2) << "Outlining the following instructions"; + for (auto* instruction_to_outline : instructions_to_outline) { + VLOG(2) << " " << instruction_to_outline->ToString(); + } + VLOG(2) << "as a call " << call->ToString(); + VLOG(2) << "to " << nested_computation->ToString(); + + computation->ReplaceUsesOfInstruction(output, call); + for (auto i = instructions_to_outline.rbegin(); + i != instructions_to_outline.rend(); ++i) { + computation->RemoveInstruction(*i); + } + + return call; +} + +std::list HloModule::MakeComputationPostOrder() const { + // First determine all root computations by building a set of nonroot + // computations (computations which are called by an instruction in the + // module). + std::set nonroot_computations; + for (auto& computation : computations_) { + for (auto& instruction : computation->instructions()) { + for (auto called_computation : instruction->MakeCalledComputationsSet()) { + nonroot_computations.insert(called_computation); + } + } + } + + // Keep track of computations which have already been added to the post + // order. This prevents duplication as an embedded computation may be called + // from two different root computations. + std::set added_computations; + std::list post_order; + for (auto& computation : computations_) { + if (nonroot_computations.count(computation.get()) == 0) { + for (HloComputation* embedded_computation : + computation->MakeEmbeddedComputationsList()) { + if (added_computations.count(embedded_computation) == 0) { + post_order.push_back(embedded_computation); + added_computations.insert(embedded_computation); + } + } + // Root computations should only be encountered once. + CHECK_EQ(0, added_computations.count(computation.get())); + post_order.push_back(computation.get()); + added_computations.insert(computation.get()); + } + } + CHECK_EQ(post_order.size(), computations_.size()); + return post_order; +} + +uint64 HloModule::RandomNew64() const { + tensorflow::mutex_lock l(rng_mutex_); + return rng_(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h new file mode 100644 index 0000000000..d598750da6 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -0,0 +1,132 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_
+
+#include <list>
+#include <memory>
+#include <random>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace xla {
+
+// Describes a compilation unit at the HLO level.
+//
+// An HLO module contains one or more HLO computations. The module contains one
+// "entry" computation which produces the result. The module also includes any
+// embedded computations used by instructions such as "map" and "reduce". All
+// computations are owned by the module.
+class HloModule {
+ public:
+  explicit HloModule(const string& name,
+                     const VersionedComputationHandle& entry_computation_handle)
+      : name_(name),
+        entry_computation_(nullptr),
+        has_entry_computation_handle_(true),
+        entry_computation_handle_(entry_computation_handle) {}
+
+  // Constructor without a versioned computation handle. This constructor
+  // should only be used for HloModules used outside of the XLA service (e.g.
+  // tests). The versioned handle is used by the service in the compilation
+  // cache.
+  explicit HloModule(const string& name)
+      : name_(name), entry_computation_(nullptr) {}
+
+  // Adds an entry computation to the module. A module can only have one entry
+  // computation. Returns a pointer to the newly added computation.
+  HloComputation* AddEntryComputation(
+      std::unique_ptr<HloComputation> computation);
+
+  // Adds an embedded computation to the module.
+  HloComputation* AddEmbeddedComputation(
+      std::unique_ptr<HloComputation> computation);
+
+  // Replaces all uses of computations that are keys of 'replacements' with
+  // the corresponding values in 'replacements'. Replaces the entry
+  // computation, if applicable.
+  //
+  // This function iterates over all instructions in the module to find
+  // computations to replace. We could speed it up by keeping track of users of
+  // computations.
+  void ReplaceComputations(
+      const std::unordered_map<HloComputation*, HloComputation*>& replacements);
+
+  const string& name() const { return name_; }
+
+  // Returns a pointer to the entry computation of the module.
+  HloComputation* entry_computation() const {
+    CHECK_NE(nullptr, entry_computation_);
+    return entry_computation_;
+  }
+
+  const VersionedComputationHandle& entry_computation_handle() const {
+    return entry_computation_handle_;
+  }
+
+  const std::vector<std::unique_ptr<HloComputation>>& computations() const {
+    return computations_;
+  }
+
+  // Computes and returns a post order of all computations in the module. The
+  // sort is defined like so: if computation A has an instruction which calls
+  // computation B, then A will appear after B in the sort.
+  std::list<HloComputation*> MakeComputationPostOrder() const;
+
+  string ToString() const;
+
+  // Outlines the given expression from the given computation.
+  // instructions_to_outline contains the instructions that form the
+  // expression.
+  //
+  // Precondition: instructions in instructions_to_outline are in topological
+  // order (root of outlined instructions last). TODO(jingyue): take a set of
+  // instructions and topologically sort them.
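+  //
+  // A minimal usage sketch (assuming `add` and `mul` are instructions in
+  // `computation`, with `mul` consuming `add`, listed in topological order):
+  //
+  //   HloInstruction* call = module->OutlineExpressionFromComputation(
+  //       {add, mul}, "outlined_add_mul", computation);
+  //
+  // Afterwards `computation` calls a new embedded computation whose root is
+  // the clone of `mul`, and `call` replaces the uses of `mul`.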
+ HloInstruction* OutlineExpressionFromComputation( + tensorflow::gtl::ArraySlice instructions_to_outline, + const string& outlined_computation_name, HloComputation* computation); + + // Returns a randomly generated uint64. + uint64 RandomNew64() const; + + private: + const string name_; + HloComputation* entry_computation_; + std::vector> computations_; + + // Random number generator engine to use when generating random numbers per + // HloModule compilation. + // TODO(b/25995601): Replace with better seed setting or dev/random for + // where we don't need deterministic execution. + mutable std::mt19937_64 rng_{42}; + mutable tensorflow::mutex rng_mutex_; + + // Versioned handle of the entry computation of the module. + bool has_entry_computation_handle_ = false; + VersionedComputationHandle entry_computation_handle_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_ diff --git a/tensorflow/compiler/xla/service/hlo_module_config.cc b/tensorflow/compiler/xla/service/hlo_module_config.cc new file mode 100644 index 0000000000..033b129e34 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_config.cc @@ -0,0 +1,53 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_module_config.h" + +#include + +#include "tensorflow/compiler/xla/shape_layout.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace xla { + +HloModuleConfig::HloModuleConfig(const ProgramShape& program_shape) + : entry_computation_layout_(program_shape) {} + +string HloModuleConfig::compilation_cache_key() const { + string key = tensorflow::strings::StrCat("profiling=", hlo_profiling_enabled_, + "::", "hybrid=", has_hybrid_result_); + tensorflow::strings::StrAppend(&key, "::("); + std::vector params; + for (const ShapeLayout& param_layout : + entry_computation_layout_.parameter_layouts()) { + params.push_back(param_layout.shape().SerializeAsString()); + } + tensorflow::strings::StrAppend( + &key, tensorflow::str_util::Join(params, ", "), ") => ", + entry_computation_layout_.result_shape().SerializeAsString()); + if (seed_ != 0) { + // TODO(b/32083678): force recompilation to reset global state. + static int counter = 0; + tensorflow::strings::StrAppend(&key, "forcing recompile ", counter++); + } + if (replica_count() != 1) { + tensorflow::strings::StrAppend(&key, "::replica_count=", replica_count()); + } + return key; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h new file mode 100644 index 0000000000..f081790869 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_config.h @@ -0,0 +1,92 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_CONFIG_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_CONFIG_H_ + +#include + +#include "tensorflow/compiler/xla/service/computation_layout.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +// This class gathers all settings and values which affect the compiled +// executable outside of the HLO code itself. This include layouts of inputs and +// outputs to the module and settings such as HLO profiling. Together the +// HloModule and HloModuleConfig unambiguously determine a particular +// executable. +class HloModuleConfig { + public: + explicit HloModuleConfig(const ProgramShape& program_shape); + + // Return a reference to the layout of the entry computation. + const ComputationLayout& entry_computation_layout() const { + return entry_computation_layout_; + } + ComputationLayout* mutable_entry_computation_layout() { + return &entry_computation_layout_; + } + + // Sets/returns whether to enable HLO-level profiling. + bool hlo_profiling_enabled() const { return hlo_profiling_enabled_; } + void enable_hlo_profiling(bool enabled) { hlo_profiling_enabled_ = enabled; } + + bool has_hybrid_result() const { return has_hybrid_result_; } + void set_has_hybrid_result(bool has_hybrid_result) { + has_hybrid_result_ = has_hybrid_result; + } + + // Sets/returns the module seed set during execution. + void set_seed(uint64 seed) { seed_ = seed; } + uint64 seed() const { return seed_; } + + void set_replica_count(int64 replica_count) { + replica_count_ = replica_count; + } + int64 replica_count() const { return replica_count_; } + + // Return a string which unambiguously represents all the fields of this data + // structure. Used for generating a cache key for storing the compiled + // executable. + string compilation_cache_key() const; + + private: + ComputationLayout entry_computation_layout_; + + // Whether to enable HLO-level profiling. + bool hlo_profiling_enabled_ = false; + + // If this flag is true, the generated executable will return a ShapedBuffer + // holding the result of the computation. In a ShapedBuffer, tuples have their + // structure held in host memory and the element arrays (leaves of the tuple + // structure) stored in device memory. The ShapedBuffer is considered "hybrid" + // because its leaves are on device but its structure is stored on + // host. Otherwise, if this flag is false, the generated executable will + // return a DeviceMemoryBase where the result is held entirely in device + // memory. + bool has_hybrid_result_ = false; + + // Module/graph-level seed handle. + uint64 seed_ = 0; + + // The number of replicas to compile this binary for. 
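+  // A count other than 1 is folded into compilation_cache_key(), which keeps
+  // replicated and unreplicated compilations from sharing a cache entry.
+  // Sketch (assuming a ProgramShape `ps` for the entry computation):
+  //
+  //   HloModuleConfig config(ps);
+  //   config.set_replica_count(2);
+  //   config.compilation_cache_key();  // ends with "::replica_count=2"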
+ int64 replica_count_ = 1; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_CONFIG_H_ diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc new file mode 100644 index 0000000000..0f4252522d --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_test.cc @@ -0,0 +1,101 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_module.h" + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace xla { + +namespace { + +class HloModuleTest : public HloTestBase { + protected: + HloModuleTest() {} + + // Create a computation which returns a constant. + std::unique_ptr CreateConstantComputation() { + auto builder = HloComputation::Builder("Constant"); + builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + return builder.Build(); + } + + // Creates a computation which calls the given zero-parameter computations. + std::unique_ptr CreateCallComputation( + tensorflow::gtl::ArraySlice computations) { + auto builder = HloComputation::Builder("Call"); + for (auto computation : computations) { + builder.AddInstruction( + HloInstruction::CreateCall(r0f32_, {}, computation)); + } + return builder.Build(); + } + + Shape r0f32_ = ShapeUtil::MakeShape(F32, {}); +}; + +TEST_F(HloModuleTest, OneComputationPostOrder) { + // Create a module with a single computation. + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(CreateConstantComputation()); + + EXPECT_EQ(module->MakeComputationPostOrder().front(), computation); +} + +TEST_F(HloModuleTest, TwoComputationsPostOrder) { + // Create a module with two unconnected computations. + auto module = MakeUnique(TestName()); + auto computation1 = module->AddEntryComputation(CreateConstantComputation()); + auto computation2 = + module->AddEmbeddedComputation(CreateConstantComputation()); + + EXPECT_MATCH( + testing::ListToVec(module->MakeComputationPostOrder()), + testing::UnorderedMatcher(computation1, computation2)); +} + +TEST_F(HloModuleTest, DiamondComputationsPostOrder) { + // Create a module with a diamond call graph of computations. 
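+  // computation1 is a constant computation called by computation2 and
+  // computation3, which are in turn both called by the entry computation4:
+  //
+  //        computation4 (entry)
+  //          /         \
+  //   computation2   computation3
+  //          \         /
+  //         computation1
+  //
+  // Any valid post order therefore starts with computation1 and ends with
+  // computation4, which is what the expectations below verify.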
+ auto module = MakeUnique(TestName()); + auto computation1 = + module->AddEmbeddedComputation(CreateConstantComputation()); + auto computation2 = + module->AddEmbeddedComputation(CreateCallComputation({computation1})); + auto computation3 = + module->AddEmbeddedComputation(CreateCallComputation({computation1})); + auto computation4 = module->AddEntryComputation( + CreateCallComputation({computation2, computation3})); + + auto post_order = module->MakeComputationPostOrder(); + EXPECT_MATCH(testing::ListToVec(post_order), + testing::UnorderedMatcher( + computation1, computation2, computation3, computation4)); + EXPECT_EQ(post_order.back(), computation4); + EXPECT_EQ(post_order.front(), computation1); +} + +} // namespace + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_opcode.cc b/tensorflow/compiler/xla/service/hlo_opcode.cc new file mode 100644 index 0000000000..2ae3858163 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_opcode.cc @@ -0,0 +1,164 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { + +string HloOpcodeString(HloOpcode opcode) { + switch (opcode) { + case HloOpcode::kAbs: + return "abs"; + case HloOpcode::kAdd: + return "add"; + case HloOpcode::kBitcast: + return "bitcast"; + case HloOpcode::kBroadcast: + return "broadcast"; + case HloOpcode::kCall: + return "call"; + case HloOpcode::kClamp: + return "clamp"; + case HloOpcode::kConcatenate: + return "concatenate"; + case HloOpcode::kConstant: + return "constant"; + case HloOpcode::kConvert: + return "convert"; + case HloOpcode::kConvolution: + return "convolution"; + case HloOpcode::kCrossReplicaSum: + return "cross-replica-sum"; + case HloOpcode::kCustomCall: + return "custom-call"; + case HloOpcode::kCopy: + return "copy"; + case HloOpcode::kDivide: + return "divide"; + case HloOpcode::kDot: + return "dot"; + case HloOpcode::kDynamicSlice: + return "dynamic-slice"; + case HloOpcode::kDynamicUpdateSlice: + return "dynamic-update-slice"; + case HloOpcode::kEq: + return "equal-to"; + case HloOpcode::kExp: + return "exponential"; + case HloOpcode::kFloor: + return "floor"; + case HloOpcode::kCeil: + return "ceil"; + case HloOpcode::kFusion: + return "fusion"; + case HloOpcode::kGe: + return "greater-than-or-equal-to"; + case HloOpcode::kGetTupleElement: + return "get-tuple-element"; + case HloOpcode::kGt: + return "greater-than"; + case HloOpcode::kIndex: + return "index"; + case HloOpcode::kInfeed: + return "infeed"; + case HloOpcode::kLe: + return "less-than-or-equal-to"; + case HloOpcode::kLog: + return "log"; + case HloOpcode::kLogicalAnd: + return "logical-and"; + case HloOpcode::kLogicalOr: + return "logical-or"; + case HloOpcode::kLogicalNot: + return "logical-not"; + case HloOpcode::kLt: + return "less-than"; + case HloOpcode::kMap: + return "map"; + 
case HloOpcode::kMaximum: + return "maximum"; + case HloOpcode::kMinimum: + return "minimum"; + case HloOpcode::kMultiply: + return "multiply"; + case HloOpcode::kNe: + return "not-equal-to"; + case HloOpcode::kNegate: + return "negate"; + case HloOpcode::kPad: + return "pad"; + case HloOpcode::kParameter: + return "parameter"; + case HloOpcode::kPower: + return "power"; + case HloOpcode::kRecv: + return "recv"; + case HloOpcode::kReduce: + return "reduce"; + case HloOpcode::kReduceWindow: + return "reduce-window"; + case HloOpcode::kRemainder: + return "remainder"; + case HloOpcode::kReshape: + return "reshape"; + case HloOpcode::kReverse: + return "reverse"; + case HloOpcode::kRng: + return "rng"; + case HloOpcode::kSelectAndScatter: + return "select-and-scatter"; + case HloOpcode::kSelect: + return "select"; + case HloOpcode::kSend: + return "send"; + case HloOpcode::kSign: + return "sign"; + case HloOpcode::kSlice: + return "slice"; + case HloOpcode::kSort: + return "sort"; + case HloOpcode::kSubtract: + return "subtract"; + case HloOpcode::kTanh: + return "tanh"; + case HloOpcode::kTrace: + return "trace"; + case HloOpcode::kTranspose: + return "transpose"; + case HloOpcode::kTuple: + return "tuple"; + case HloOpcode::kUpdate: + return "update"; + case HloOpcode::kWhile: + return "while"; + } +} + +bool HloOpcodeIsComparison(HloOpcode opcode) { + switch (opcode) { + case HloOpcode::kGe: + case HloOpcode::kGt: + case HloOpcode::kLe: + case HloOpcode::kLt: + case HloOpcode::kEq: + case HloOpcode::kNe: + return true; + default: + return false; + } +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h new file mode 100644 index 0000000000..a8631b9c7d --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -0,0 +1,107 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_OPCODE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_OPCODE_H_ + +#include +#include +#include "tensorflow/compiler/xla/types.h" + +namespace xla { + +// High-level optimizer instruction opcodes -- these are linear-algebra level +// opcodes. They are a flattened form of the UnaryOp, BinaryOp, ... opcodes +// present in the XLA service protobuf. +// +// See the XLA documentation for the semantics of each opcode. 
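+//
+// For example (sketch): an elementwise add of two rank-1 f32 operands is an
+// instruction built with HloOpcode::kAdd:
+//
+//   Shape r1f32 = ShapeUtil::MakeShape(F32, {4});
+//   auto x = HloInstruction::CreateParameter(0, r1f32, "x");
+//   auto y = HloInstruction::CreateParameter(1, r1f32, "y");
+//   auto add = HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd,
+//                                           x.get(), y.get());
+//
+// HloOpcodeString(HloOpcode::kAdd) returns "add" for this opcode.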
+enum class HloOpcode { + kAbs, + kAdd, + kBitcast, + kBroadcast, + kCall, + kCeil, + kClamp, + kConcatenate, + kConstant, + kConvert, + kConvolution, + kCopy, + kCrossReplicaSum, + kCustomCall, + kDivide, + kDot, + kDynamicSlice, + kDynamicUpdateSlice, + kEq, + kExp, + kFloor, + kFusion, + kGe, + kGetTupleElement, + kGt, + kIndex, + kInfeed, + kLe, + kLog, + kLogicalAnd, + kLogicalNot, + kLogicalOr, + kLt, + kMap, + kMaximum, + kMinimum, + kMultiply, + kNe, + kNegate, + kPad, + kParameter, + kPower, + kRecv, + kReduce, + kReduceWindow, + kRemainder, + kReshape, + kReverse, + kRng, + kSelect, + kSelectAndScatter, + kSend, + kSign, + kSlice, + kSort, + kSubtract, + kTanh, + kTrace, + kTranspose, + kTuple, + kUpdate, + kWhile, +}; + +// Returns a string representation of the opcode. +string HloOpcodeString(HloOpcode opcode); + +inline std::ostream& operator<<(std::ostream& os, HloOpcode opcode) { + return os << HloOpcodeString(opcode); +} + +// Returns true iff the given opcode is a comparison operation. +bool HloOpcodeIsComparison(HloOpcode opcode); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_OPCODE_H_ diff --git a/tensorflow/compiler/xla/service/hlo_opcode_test.cc b/tensorflow/compiler/xla/service/hlo_opcode_test.cc new file mode 100644 index 0000000000..0b64c16fdc --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_opcode_test.cc @@ -0,0 +1,30 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_opcode.h" + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +// This test verifies that an example HloOpcode stringifies as expected. +TEST(HloOpcodeTest, StringifyMultiply) { + ASSERT_EQ("multiply", HloOpcodeString(HloOpcode::kMultiply)); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_pass.h b/tensorflow/compiler/xla/service/hlo_pass.h new file mode 100644 index 0000000000..c83d0eed00 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_pass.h @@ -0,0 +1,68 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_H_
+
+#include <string>
+
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace xla {
+
+// Base class for HLO passes. These are used with the HloPassPipeline to
+// organize a sequence of passes.
+class HloPass {
+ public:
+  explicit HloPass(const string& name) : name_(name) {}
+  virtual ~HloPass() {}
+
+  const string& name() const { return name_; }
+
+  // Runs the pass on the given HLO module. Returns whether it modified the
+  // module.
+  virtual StatusOr<bool> Run(HloModule* module) = 0;
+
+ private:
+  const string name_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(HloPass);
+};
+
+// Runs an HLO pass to a fixed point.
+template <typename Pass>
+class HloPassFix : public Pass {
+ public:
+  template <typename... Args>
+  explicit HloPassFix(Args&&... args) : Pass(args...) {}
+
+  StatusOr<bool> Run(HloModule* module) override {
+    bool changed = false;
+    bool changed_this_iteration = true;
+    while (changed_this_iteration) {
+      TF_ASSIGN_OR_RETURN(changed_this_iteration, Pass::Run(module));
+      changed |= changed_this_iteration;
+    }
+    return changed;
+  }
+};
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_H_
diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
new file mode 100644
index 0000000000..fcfd242a85
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
@@ -0,0 +1,64 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
+
+#include <vector>
+
+#include "tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+
+StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
+  legacy_flags::HloPassPipelineFlags* flags =
+      legacy_flags::GetHloPassPipelineFlags();
+  std::vector<string> tmp =
+      tensorflow::str_util::Split(flags->xla_disable_hlo_passes, ',');
+  tensorflow::gtl::FlatSet<string> disabled_passes(tmp.begin(), tmp.end());
+
+  string prefix = name() + ": pipeline start";
+  bool changed = false;
+  string message;
+  for (auto& pass : passes_) {
+    if (!disabled_passes.empty() && disabled_passes.count(pass->name()) > 0) {
+      continue;
+    }
+
+    // Emit label containing: "after foo-pass, before bar-pass".
+ message.clear(); + tensorflow::strings::StrAppend(&message, prefix, ", before ", pass->name()); + dumper_(*module, message); + + VLOG(2) << "HLO " << message << ":"; + XLA_VLOG_LINES(2, module->ToString()); + + TF_ASSIGN_OR_RETURN(bool changed_this_pass, pass->Run(module)); + + changed |= changed_this_pass; + prefix.clear(); + tensorflow::strings::StrAppend(&prefix, name(), ": after ", pass->name()); + } + dumper_(*module, prefix + ", pipeline end"); + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h new file mode 100644 index 0000000000..f49eed8cba --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h @@ -0,0 +1,66 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_PIPELINE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_PIPELINE_H_ + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/compiler.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/macros.h" + +namespace xla { + +// Pipeline of HLO passes. +class HloPassPipeline : public HloPass { + public: + explicit HloPassPipeline(const string& name, + const Compiler::HloDumper& dumper) + : HloPass(name), dumper_(dumper) {} + + // Add a pass to the pipeline. It should be called with the arguments for the + // pass constructor: + // + // pipeline.AddPass(constructor_arg1, constructor_arg2); + // + // Returns a reference to the added pass. + template + T& AddPass(Args&&... args) { + auto pass = new T(std::forward(args)...); + passes_.push_back(std::unique_ptr(pass)); + return *pass; + } + + // Run all passes on the given HLO module. + StatusOr Run(HloModule* module) override; + + private: + Compiler::HloDumper dumper_; + std::vector> passes_; + + TF_DISALLOW_COPY_AND_ASSIGN(HloPassPipeline); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_PIPELINE_H_ diff --git a/tensorflow/compiler/xla/service/hlo_query.cc b/tensorflow/compiler/xla/service/hlo_query.cc new file mode 100644 index 0000000000..1556d1772f --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_query.cc @@ -0,0 +1,89 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_query.h" + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" + +namespace xla { +namespace hlo_query { + +bool IsConstantR0F32(HloInstruction* instruction, float* out) { + if (instruction->opcode() == HloOpcode::kConstant && + ShapeUtil::IsScalarF32(instruction->shape())) { + *out = LiteralUtil::Get(instruction->literal(), {}); + return true; + } + + return false; +} + +bool AllOperandsAreParameters(const HloInstruction& instruction) { + for (const auto& operand : instruction.operands()) { + if (operand->opcode() != HloOpcode::kParameter) { + return false; + } + } + return true; +} + +HloInstruction* GetMatchingOperand( + std::function matcher, + HloInstruction* instruction) { + for (HloInstruction* op : instruction->operands()) { + if (matcher(op)) { + return op; + } + } + return nullptr; +} + +bool MatchBinaryInstructionOperand( + std::function matcher, + HloInstruction* instruction, HloInstruction** matching_operand, + HloInstruction** other_operand) { + CHECK_EQ(instruction->operand_count(), 2); + if (matcher(instruction->operand(0))) { + *matching_operand = instruction->mutable_operand(0); + *other_operand = instruction->mutable_operand(1); + return true; + } + if (matcher(instruction->operand(1))) { + *matching_operand = instruction->mutable_operand(1); + *other_operand = instruction->mutable_operand(0); + return true; + } + return false; +} + +bool MatchBinaryInstructionOperandOpcode(HloOpcode opcode, + HloInstruction* instruction, + HloInstruction** matching_operand, + HloInstruction** other_operand) { + return MatchBinaryInstructionOperand( + [opcode](const HloInstruction* instruction) { + return instruction->opcode() == opcode; + }, + instruction, matching_operand, other_operand); +} + +bool IsScalarConstant(const HloInstruction* instruction) { + return instruction->IsConstant() && ShapeUtil::IsScalar(instruction->shape()); +} + +} // namespace hlo_query +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_query.h b/tensorflow/compiler/xla/service/hlo_query.h new file mode 100644 index 0000000000..864f892e92 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_query.h @@ -0,0 +1,63 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_QUERY_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_QUERY_H_
+
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+
+namespace xla {
+
+// Helper interface for making queries about the HLO IR.
+namespace hlo_query {
+
+// Returns whether the instruction provided is a constant rank-0 float32, and
+// if so, places the constant value into out.
+// Precondition: out != nullptr
+bool IsConstantR0F32(HloInstruction* instruction, float* out);
+
+// Returns whether all of an instruction's operands are parameters.
+bool AllOperandsAreParameters(const HloInstruction& instruction);
+
+// Returns whether the instruction is a scalar constant.
+bool IsScalarConstant(const HloInstruction* instruction);
+
+// Returns an operand of an instruction that satisfies the given matcher. If
+// there are multiple matching operands, then the first matching operand is
+// returned. If there are no matching operands then nullptr is returned.
+HloInstruction* GetMatchingOperand(
+    std::function<bool(const HloInstruction*)> matcher,
+    HloInstruction* instruction);
+
+// Returns whether a binary instruction has a matching operand. Sets
+// matching_operand to the matching operand and the other operand to
+// other_operand. Note: in the case where both operands match, the first
+// operand of the instruction is returned.
+bool MatchBinaryInstructionOperand(
+    std::function<bool(const HloInstruction*)> matcher,
+    HloInstruction* instruction, HloInstruction** matching_operand,
+    HloInstruction** other_operand);
+
+// Returns whether a binary instruction has an operand with a given opcode.
+// This is a special case of MatchBinaryInstructionOperand.
+bool MatchBinaryInstructionOperandOpcode(HloOpcode opcode,
+                                         HloInstruction* instruction,
+                                         HloInstruction** matching_operand,
+                                         HloInstruction** other_operand);
+
+}  // namespace hlo_query
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_QUERY_H_
diff --git a/tensorflow/compiler/xla/service/hlo_subcomputation_unification.cc b/tensorflow/compiler/xla/service/hlo_subcomputation_unification.cc
new file mode 100644
index 0000000000..460dc5cf64
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_subcomputation_unification.cc
@@ -0,0 +1,45 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
+
+#include <unordered_map>
+
+namespace xla {
+
+StatusOr<bool> HloSubcomputationUnification::Run(HloModule* module) {
+  // For each computation C in the module, find the first computation C0 in
+  // the computations_ list that is identical to C and add canon[C] = C0.
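+  // For example, if computations_ is {E, A, B} and A and B are identical,
+  // canon ends up as {B -> A}; ReplaceComputations() then rewrites every call
+  // to B into a call to A and drops B from the module.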
+ std::unordered_map canon; + const auto& computations = module->computations(); + for (auto i = computations.begin(); i != computations.end(); ++i) { + for (auto j = computations.begin(); j < i; ++j) { + // Do not waste time comparing `*i` with `*j` if `*j` is not canonical. + if (canon.find(j->get()) == canon.end() && **i == **j) { + canon[i->get()] = j->get(); + break; + } + } + } + + if (canon.empty()) { + return false; + } + + module->ReplaceComputations(canon); + return true; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h b/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h new file mode 100644 index 0000000000..9ac3d5702d --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h @@ -0,0 +1,34 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SUBCOMPUTATION_UNIFICATION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SUBCOMPUTATION_UNIFICATION_H_ + +#include "tensorflow/compiler/xla/service/hlo_pass.h" + +namespace xla { + +// Unify subcomputations of a `HloModule`: if any computations are equal, choose +// one arbitrarily to use and delete the others. +class HloSubcomputationUnification : public HloPass { + public: + HloSubcomputationUnification() : HloPass("subcomputation unification") {} + + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SUBCOMPUTATION_UNIFICATION_H_ diff --git a/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc b/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc new file mode 100644 index 0000000000..14800b5342 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc @@ -0,0 +1,205 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" + +namespace xla { + +class HloSubcomputationUnificationTest : public HloTestBase { + protected: + HloSubcomputationUnificationTest() {} + + std::unique_ptr CreateR0S32IdentityComputation() { + auto builder = HloComputation::Builder("Identity"); + builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32_, "x")); + return builder.Build(); + } + + std::unique_ptr CreateR0S32AdditionComputation() { + auto builder = HloComputation::Builder("Addition"); + auto x = + builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32_, "x")); + auto y = + builder.AddInstruction(HloInstruction::CreateParameter(1, r0s32_, "y")); + builder.AddInstruction( + HloInstruction::CreateBinary(r0s32_, HloOpcode::kAdd, x, y)); + return builder.Build(); + } + + std::unique_ptr CreateR1S32AdditionComputation( + const Shape& shape) { + auto builder = HloComputation::Builder("Addition"); + auto x = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x")); + auto y = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "y")); + builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, x, y)); + return builder.Build(); + } + + Shape r0s32_ = ShapeUtil::MakeShape(S32, {}); + Shape r0f32_ = ShapeUtil::MakeShape(S32, {}); + Shape r1s32_5_ = ShapeUtil::MakeShape(S32, {5}); + Shape r1s32_3_ = ShapeUtil::MakeShape(S32, {3}); +}; + +TEST_F(HloSubcomputationUnificationTest, UnifyIdentities) { + auto hlo_module = MakeUnique("test_module"); + auto builder = HloComputation::Builder(TestName()); + + auto callee1 = + hlo_module->AddEmbeddedComputation(CreateR0S32IdentityComputation()); + auto callee2 = + hlo_module->AddEmbeddedComputation(CreateR0S32IdentityComputation()); + + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(5))); + auto x = builder.AddInstruction( + HloInstruction::CreateCall(r0s32_, {constant}, callee1)); + auto y = builder.AddInstruction( + HloInstruction::CreateCall(r0s32_, {constant}, callee2)); + builder.AddInstruction( + HloInstruction::CreateBinary(r0s32_, HloOpcode::kAdd, x, y)); + + hlo_module->AddEntryComputation(builder.Build()); + + EXPECT_EQ(3, hlo_module->computations().size()); + EXPECT_NE(x->to_apply(), y->to_apply()); + if (VLOG_IS_ON(1)) { + hlo_graph_dumper::DumpGraph(*hlo_module->entry_computation(), + "before unification", false, false, nullptr); + } + EXPECT_TRUE( + HloSubcomputationUnification().Run(hlo_module.get()).ValueOrDie()); + if (VLOG_IS_ON(1)) { + hlo_graph_dumper::DumpGraph(*hlo_module->entry_computation(), + "after unification", false, false, nullptr); + } + EXPECT_EQ(2, hlo_module->computations().size()); + EXPECT_EQ(x->to_apply(), y->to_apply()); +} + +TEST_F(HloSubcomputationUnificationTest, UnifyAdditions) { + auto hlo_module = MakeUnique("test_module"); + auto builder = HloComputation::Builder(TestName()); + + auto callee1 = + 
hlo_module->AddEmbeddedComputation(CreateR0S32AdditionComputation()); + auto callee2 = + hlo_module->AddEmbeddedComputation(CreateR0S32AdditionComputation()); + + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(5))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(3))); + auto x = builder.AddInstruction( + HloInstruction::CreateCall(r0s32_, {constant1, constant2}, callee1)); + auto y = builder.AddInstruction( + HloInstruction::CreateCall(r0s32_, {constant1, constant2}, callee2)); + builder.AddInstruction( + HloInstruction::CreateBinary(r0s32_, HloOpcode::kAdd, x, y)); + + hlo_module->AddEntryComputation(builder.Build()); + + EXPECT_EQ(3, hlo_module->computations().size()); + EXPECT_NE(x->to_apply(), y->to_apply()); + if (VLOG_IS_ON(1)) { + hlo_graph_dumper::DumpGraph(*hlo_module->entry_computation(), + "before unification", false, false, nullptr); + } + EXPECT_TRUE( + HloSubcomputationUnification().Run(hlo_module.get()).ValueOrDie()); + if (VLOG_IS_ON(1)) { + hlo_graph_dumper::DumpGraph(*hlo_module->entry_computation(), + "after unification", false, false, nullptr); + } + EXPECT_EQ(2, hlo_module->computations().size()); + EXPECT_EQ(x->to_apply(), y->to_apply()); +} + +// Do not unify subcomputations with different parameter shapes. +TEST_F(HloSubcomputationUnificationTest, DifferentParameterShapes) { + auto hlo_module = MakeUnique("test_module"); + auto builder = HloComputation::Builder(TestName()); + + auto callee1 = hlo_module->AddEmbeddedComputation( + CreateR1S32AdditionComputation(r1s32_5_)); + auto callee2 = hlo_module->AddEmbeddedComputation( + CreateR1S32AdditionComputation(r1s32_3_)); + + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r1s32_5_, "param1")); + auto param2 = builder.AddInstruction( + HloInstruction::CreateParameter(1, r1s32_5_, "param2")); + auto x = builder.AddInstruction( + HloInstruction::CreateCall(r1s32_5_, {param1, param1}, callee1)); + auto y = builder.AddInstruction( + HloInstruction::CreateCall(r1s32_3_, {param2, param2}, callee2)); + builder.AddInstruction(HloInstruction::CreateConcatenate( + ShapeUtil::MakeShape(S32, {8}), {x, y}, 0)); + + hlo_module->AddEntryComputation(builder.Build()); + + EXPECT_EQ(3, hlo_module->computations().size()); + EXPECT_NE(x->to_apply(), y->to_apply()); + if (VLOG_IS_ON(1)) { + hlo_graph_dumper::DumpGraph(*hlo_module->entry_computation(), + "before unification", false, false, nullptr); + } + EXPECT_FALSE( + HloSubcomputationUnification().Run(hlo_module.get()).ValueOrDie()); + if (VLOG_IS_ON(1)) { + hlo_graph_dumper::DumpGraph(*hlo_module->entry_computation(), + "after unification", false, false, nullptr); + } + EXPECT_EQ(3, hlo_module->computations().size()); + EXPECT_NE(x->to_apply(), y->to_apply()); +} + +// Regression test for b/31466798. Checks that entry_computation is still valid +// after unification. 
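+// The entry computation is added second, so the embedded copy becomes the
+// canonical one; after the pass runs, entry_computation() must point at the
+// single surviving computation rather than at a deleted duplicate.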
+TEST_F(HloSubcomputationUnificationTest, TwoIdenticalComputations) { + HloModule module(TestName()); + for (int i = 0; i < 2; ++i) { + HloComputation::Builder builder("pow"); + auto x = + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x")); + auto y = + builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "y")); + builder.AddInstruction( + HloInstruction::CreateBinary(r0f32_, HloOpcode::kPower, x, y)); + if (i == 0) { + module.AddEmbeddedComputation(builder.Build()); + } else { + module.AddEntryComputation(builder.Build()); + } + } + + EXPECT_TRUE(HloSubcomputationUnification().Run(&module).ValueOrDie()); + EXPECT_EQ(1, module.computations().size()); + EXPECT_EQ(module.computations().front().get(), module.entry_computation()); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/inliner.cc b/tensorflow/compiler/xla/service/inliner.cc new file mode 100644 index 0000000000..bf6e3dd84d --- /dev/null +++ b/tensorflow/compiler/xla/service/inliner.cc @@ -0,0 +1,123 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/inliner.h" + +#include +#include + +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_query.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +// InlinerVisitor traverses the HLO computation and inlines maps. +class InlinerVisitor : public DfsHloVisitorWithDefault { + public: + explicit InlinerVisitor(HloComputation* computation) + : computation_(computation) {} + + // Default visitor action is to do nothing and return OK. + Status DefaultAction(HloInstruction* /*hlo_instruction*/) override { + return Status::OK(); + } + + Status HandleMap( + HloInstruction* map, + tensorflow::gtl::ArraySlice operands, + HloComputation* function, + tensorflow::gtl::ArraySlice static_operands) override; + + // Runs the visitor on a computation. + StatusOr Run(HloComputation* computation); + + private: + // Current HloComputation instance the InlinerVisitor is traversing. + HloComputation* computation_; + + // Whether algebraic simplification has occurred. 
+  bool changed_ = false;
+};
+
+StatusOr<bool> InlinerVisitor::Run(HloComputation* computation) {
+  changed_ = false;
+  computation_ = computation;
+  TF_RETURN_IF_ERROR(computation->root_instruction()->Accept(this));
+  return changed_;
+}
+
+Status InlinerVisitor::HandleMap(
+    HloInstruction* map, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+    HloComputation* function,
+    tensorflow::gtl::ArraySlice<HloInstruction*> /*static_operands*/) {
+  HloInstruction& root = *function->root_instruction();
+  // TODO(b/29249531): Add DCE pass to remove unused HloComputations.
+  // Only inline functions that are a single operation until a better
+  // profitability model for inlining is defined.
+  if (hlo_query::AllOperandsAreParameters(root)) {
+    if (root.opcode() == HloOpcode::kUpdate ||
+        root.opcode() == HloOpcode::kFusion ||
+        root.opcode() == HloOpcode::kIndex ||
+        root.opcode() == HloOpcode::kParameter ||
+        root.opcode() == HloOpcode::kTrace) {
+      // Cloning not supported for these instructions.
+      return Status::OK();
+    }
+    VLOG(10) << "inlining map({X ... Y}, op) => : op(X ... Y) with function "
+             << root.ToShortString();
+    // If the input is a constant then the shape of the constant could be
+    // different than the map shape. Hence, a broadcast is needed; otherwise
+    // the cloned operand with the new shape and operands works.
+    if (root.opcode() != HloOpcode::kConstant) {
+      HloInstruction* placed_instruction = computation_->AddInstruction(
+          root.CloneWithNewOperands(map->shape(), operands));
+      computation_->ReplaceInstruction(map, placed_instruction);
+    } else {
+      // The constant is in an embedded computation and needs to be recreated
+      // as part of the computation that the broadcast is inserted into.
+      HloInstruction* constant = computation_->AddInstruction(root.Clone());
+      HloInstruction* placed_instruction = computation_->AddInstruction(
+          HloInstruction::CreateBroadcast(map->shape(), constant, {}));
+      computation_->ReplaceInstruction(map, placed_instruction);
+    }
+    changed_ = true;
+    return Status::OK();
+  }
+
+  return Status::OK();
+}
+
+StatusOr<bool> Inliner::Run(HloModule* module) {
+  InlinerVisitor visitor(/*computation=*/nullptr);
+  bool changed = false;
+  for (const std::unique_ptr<HloComputation>& computation :
+       module->computations()) {
+    TF_ASSIGN_OR_RETURN(bool computation_changed,
+                        visitor.Run(computation.get()));
+    changed |= computation_changed;
+  }
+  return changed;
+}
+
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/inliner.h b/tensorflow/compiler/xla/service/inliner.h
new file mode 100644
index 0000000000..5d53443b83
--- /dev/null
+++ b/tensorflow/compiler/xla/service/inliner.h
@@ -0,0 +1,39 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
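A minimal usage sketch for the pass defined above, assuming an HloModule that has already been built elsewhere; the wrapper name RunInlinerOnce is illustrative and not part of this change:

#include "tensorflow/compiler/xla/service/inliner.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace xla {

// Runs the inliner over every computation in the module and reports whether
// any kMap instruction was rewritten into its mapped operation.
StatusOr<bool> RunInlinerOnce(HloModule* module) {
  Inliner inliner;
  return inliner.Run(module);
}

}  // namespace xla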
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INLINER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_INLINER_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass.h" + +namespace xla { + +// A pass which performs inlining. Which can result, for example, in functions +// that were previously being mapped by Map instead directly applied to the +// forwarded operands (i.e., map({X, Y}, max) -> max(X, Y)). +class Inliner : public HloPass { + public: + Inliner() : HloPass("inline") {} + ~Inliner() override = default; + + // Run inlining on the given computation. Returns whether the computation was + // changed. + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_INLINER_H_ diff --git a/tensorflow/compiler/xla/service/inliner_test.cc b/tensorflow/compiler/xla/service/inliner_test.cc new file mode 100644 index 0000000000..0054edcf6a --- /dev/null +++ b/tensorflow/compiler/xla/service/inliner_test.cc @@ -0,0 +1,109 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/inliner.h" + +#include +#include + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace { + +using InlinerTest = HloTestBase; + +// Test that `map` with `max` is transformed to `max` +TEST_F(InlinerTest, MapMax) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + + auto max_builder = HloComputation::Builder(TestName()); + auto param1 = max_builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "x")); + auto param2 = max_builder.AddInstruction( + HloInstruction::CreateParameter(1, r0f32, "y")); + max_builder.AddInstruction(HloInstruction::CreateBinary( + param1->shape(), HloOpcode::kMaximum, param1, param2)); + auto max_f32 = max_builder.Build(); + + auto builder = HloComputation::Builder("MapMaxFunction"); + auto lhs = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1, 2, 3, 4}))); + auto rhs = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({4, 3, 2, 1}))); + builder.AddInstruction( + HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, max_f32.get())); + + auto computation = builder.Build(); + auto hlo_module = MakeUnique("test_module"); + 
hlo_module->AddEmbeddedComputation(std::move(max_f32)); + hlo_module->AddEntryComputation(std::move(computation)); + HloInstruction* root = hlo_module->entry_computation()->root_instruction(); + Inliner inliner; + EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie()); + root = hlo_module->entry_computation()->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kMaximum); + + // Verify execution on CPU. + auto result = ExecuteAndTransfer(std::move(hlo_module), {}); + auto expected = LiteralUtil::CreateR1({4, 3, 3, 4}); + LiteralTestUtil::ExpectEqual(*result, *expected); +} + +// Test that `constant` function is changed to `broadcast`. +TEST_F(InlinerTest, MapConstant) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + + auto const2_builder = HloComputation::Builder(TestName()); + auto param1 = const2_builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "x")); + (void)param1; + const2_builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0f))); + auto const2_f32 = const2_builder.Build(); + + auto builder = HloComputation::Builder("MapConstFunction"); + auto lhs = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{1, 2, 3, 4}, {5, 6, 7, 8}}))); + builder.AddInstruction( + HloInstruction::CreateMap(lhs->shape(), {lhs}, const2_f32.get())); + + auto computation = builder.Build(); + auto hlo_module = MakeUnique("test_module"); + hlo_module->AddEmbeddedComputation(std::move(const2_f32)); + hlo_module->AddEntryComputation(std::move(computation)); + HloInstruction* root = hlo_module->entry_computation()->root_instruction(); + Inliner inliner; + EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie()); + root = hlo_module->entry_computation()->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); + + // Verify execution on CPU. + auto result = ExecuteAndTransfer(std::move(hlo_module), {}); + auto expected = LiteralUtil::CreateR2({{2, 2, 2, 2}, {2, 2, 2, 2}}); + LiteralTestUtil::ExpectEqual(*result, *expected); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc new file mode 100644 index 0000000000..58f8dc92cc --- /dev/null +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -0,0 +1,295 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/instruction_fusion.h" + +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +bool IsExpensive(const HloInstruction& instruction) { + switch (instruction.opcode()) { + // Cheap instructions. 
+ case HloOpcode::kAbs: + case HloOpcode::kAdd: + case HloOpcode::kBitcast: + case HloOpcode::kBroadcast: + case HloOpcode::kCeil: + case HloOpcode::kClamp: + case HloOpcode::kConcatenate: + case HloOpcode::kConstant: + case HloOpcode::kConvert: + case HloOpcode::kCopy: + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: + case HloOpcode::kEq: + case HloOpcode::kFloor: + case HloOpcode::kGe: + case HloOpcode::kGetTupleElement: + case HloOpcode::kGt: + case HloOpcode::kInfeed: + case HloOpcode::kLe: + case HloOpcode::kLogicalAnd: + case HloOpcode::kLogicalNot: + case HloOpcode::kLogicalOr: + case HloOpcode::kLt: + case HloOpcode::kMaximum: + case HloOpcode::kMinimum: + case HloOpcode::kMultiply: + case HloOpcode::kNe: + case HloOpcode::kNegate: + case HloOpcode::kPad: + case HloOpcode::kReshape: + case HloOpcode::kReverse: + case HloOpcode::kSelect: + case HloOpcode::kSign: + case HloOpcode::kSlice: + case HloOpcode::kSubtract: + case HloOpcode::kTranspose: + case HloOpcode::kTuple: + return false; + + // Expensive instructions. + case HloOpcode::kCall: + case HloOpcode::kConvolution: + case HloOpcode::kCrossReplicaSum: + case HloOpcode::kCustomCall: + case HloOpcode::kDivide: + case HloOpcode::kDot: + case HloOpcode::kExp: + case HloOpcode::kFusion: + case HloOpcode::kIndex: + case HloOpcode::kLog: + case HloOpcode::kMap: + case HloOpcode::kParameter: + case HloOpcode::kPower: + case HloOpcode::kReduce: + case HloOpcode::kReduceWindow: + case HloOpcode::kRemainder: + case HloOpcode::kRng: + case HloOpcode::kSelectAndScatter: + case HloOpcode::kSort: + case HloOpcode::kTanh: + case HloOpcode::kTrace: + case HloOpcode::kUpdate: + case HloOpcode::kWhile: + case HloOpcode::kSend: + case HloOpcode::kRecv: + return true; + } +} + +bool FusionWouldDuplicate(HloInstruction* producer, HloInstruction* consumer) { + return !(producer->users().size() == 1 && + producer->users().count(consumer) == 1); +} + +StatusOr InstructionFusion::Run(HloModule* module) { + bool changed = false; + for (auto& computation : module->computations()) { + computation_ = computation.get(); + + // We want to be able to remove arbitrary instructions from the post order + // and also compare positions of instructions in the post order. To make + // this possible, create vector of instructions in post order and create a + // map from HloInstruction* to the instruction's index in the vector. An + // instruction is "removed" from the vector by setting it's element to + // nullptr. + std::list post_order_list = + computation_->MakeInstructionPostOrder(); + std::vector post_order(post_order_list.begin(), + post_order_list.end()); + tensorflow::gtl::FlatMap post_order_index; + for (int i = 0; i < post_order.size(); ++i) { + InsertOrDie(&post_order_index, post_order[i], i); + } + + // Instruction fusion effectively fuses edges in the computation graph + // (producer instruction -> consumer instruction) so we iterate over all + // edges. When we fuse an edge, we create a copy of the producer inside the + // fusion instruction. + while (!post_order.empty()) { + // We want to iterate in reverse post order, so remove from the back of + // the vector. + HloInstruction* instruction = post_order.back(); + post_order.pop_back(); + + // Instructions are "removed" from the post order by nulling out the + // element in the vector, so if the pointer is null, continue to the next + // instruction in the sort. 
+ if (instruction == nullptr) { + continue; + } + + // Remove instruction from the index map to ensure the vector and map stay + // consistent. + post_order_index.erase(instruction); + + if (!instruction->IsFusable() && + instruction->opcode() != HloOpcode::kFusion) { + continue; + } + + // Consider each operand of this instruction for fusion into this + // instruction. We want to consider the operands in a particular order to + // avoid created duplicate instruction clones in the fusion instruction. + // For example, consider the following expression: + // + // A = ... + // B = op(A) + // C = op(A, B) + // + // If we are considering the operands of C for fusion into C. We might + // fuse A or B first. If we fuse A first, we get: + // + // A = ... + // B = op(A) + // C_fusion = { A' = ... + // C' = op(A', B) } + // + // Where A' and C' are clones of A and C, respectively. Now only B is an + // operand of the fusion instruction C_fusion, so then we fuse B: + // + // A = ... + // B = op(A) + // C_fusion = { A' = ... + // B' = op(A) + // C' = op(A', B') } + // + // Now A is an operand of C_fusion again, so we then fuse A (again!): + // + // A = ... + // B = op(A) + // C_fusion = { A' = ... + // A" = .. + // B' = op(A") + // C' = op(A', B') } + // + // We prevent this duplication by considering the operands in the reverse + // order they appear in the instruction post order. In the example, this + // ensures that B will be considered before A. + // + // We store the original indices of the operands to pass to ShouldFuse. + std::vector sorted_operand_numbers(instruction->operands().size()); + std::iota(std::begin(sorted_operand_numbers), + std::end(sorted_operand_numbers), 0); + std::sort( + sorted_operand_numbers.begin(), sorted_operand_numbers.end(), + [&](int64 i, int64 j) { + // Instructions with higher indices in the post order come + // first. + return ( + FindOrDie(post_order_index, instruction->mutable_operand(i)) > + FindOrDie(post_order_index, instruction->mutable_operand(j))); + }); + + for (int64 i : sorted_operand_numbers) { + HloInstruction* operand = instruction->mutable_operand(i); + if (operand->IsFusable() && ShouldFuse(instruction, i)) { + HloInstruction* fusion_instruction = Fuse(operand, instruction); + + // Fusing an instruction into a fusion instruction can change the + // operand set of the fusion instruction. For simplicity just push the + // instruction to the top of the post_order and reconsider it for + // further fusion in the next iteration of the outer loop. + post_order.push_back(fusion_instruction); + InsertOrDie(&post_order_index, fusion_instruction, + post_order.size() - 1); + changed = true; + + if (operand->user_count() == 0) { + // Operand is now dead. Remove from post order by setting it's + // location to nullptr. + post_order[FindOrDie(post_order_index, operand)] = nullptr; + post_order_index.erase(operand); + + // Remove from computation. 
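The operand ordering described above can be sketched in isolation; the helper below is a standalone illustration using plain std::vector and int in place of the FlatMap and HloInstruction pointers used here, and its names are chosen only for this example:

#include <algorithm>
#include <cassert>
#include <numeric>
#include <vector>

// Orders operand indices so that operands appearing later in the post order
// (closer to the consumer) are considered for fusion first, which is what
// prevents duplicate clones inside the fusion instruction.
std::vector<int> SortByDescendingPostOrder(
    const std::vector<int>& operand_post_order_positions) {
  std::vector<int> order(operand_post_order_positions.size());
  std::iota(order.begin(), order.end(), 0);
  std::sort(order.begin(), order.end(), [&](int i, int j) {
    return operand_post_order_positions[i] > operand_post_order_positions[j];
  });
  return order;
}

int main() {
  // Operand 0 (A) sits at post-order position 5, operand 1 (B) at position 9,
  // so B is considered before A.
  std::vector<int> order = SortByDescendingPostOrder({5, 9});
  assert(order[0] == 1 && order[1] == 0);
  return 0;
}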
+ computation_->RemoveInstruction(operand); + } + break; + } + } + } + } + return changed; +} + +HloInstruction* InstructionFusion::Fuse(HloInstruction* producer, + HloInstruction* consumer) { + HloInstruction* fusion_instruction; + + VLOG(2) << "Fusing " << producer << " into " << consumer; + + if (consumer->opcode() == HloOpcode::kFusion) { + fusion_instruction = consumer; + } else { + fusion_instruction = + computation_->AddInstruction(HloInstruction::CreateFusion( + consumer->shape(), ChooseKind(producer, consumer), consumer)); + computation_->ReplaceInstruction(consumer, fusion_instruction); + } + fusion_instruction->FuseInstruction(producer); + + return fusion_instruction; +} + +bool InstructionFusion::ShouldFuse(HloInstruction* consumer, + int64 operand_index) { + HloInstruction* producer = consumer->mutable_operand(operand_index); + // Cost condition: don't duplicate expensive instructions. + if (FusionWouldDuplicate(producer, consumer) && + (IsExpensive(*producer) || !may_duplicate_)) { + return false; + } + + if (consumer->opcode() == HloOpcode::kFusion && + consumer->fusion_kind() != HloInstruction::FusionKind::kLoop && + consumer->fusion_kind() != HloInstruction::FusionKind::kInput) { + return false; + } + + // Cost condition: not fuse (expensive producers) and (consumers who reuse + // operand elements). + if (consumer->ReusesOperandElements(operand_index) && + IsExpensive(*producer)) { + return false; + } + + if (producer->CouldBeBitcast() && + // We can't fuse parameters anyhow, so we leave the user unfused to become + // a bitcast. If the operand is not a parameter, we would break a + // potential fusion to make it a bitcast, which is not so clear a win. + producer->operand(0)->opcode() == HloOpcode::kParameter) { + return false; + } + + return true; +} + +HloInstruction::FusionKind InstructionFusion::ChooseKind( + const HloInstruction* producer, const HloInstruction* consumer) { + return HloInstruction::FusionKind::kLoop; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h new file mode 100644 index 0000000000..902df2dcd0 --- /dev/null +++ b/tensorflow/compiler/xla/service/instruction_fusion.h @@ -0,0 +1,84 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INSTRUCTION_FUSION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_INSTRUCTION_FUSION_H_ + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass.h" +#include "tensorflow/core/platform/macros.h" + +namespace xla { + +// Returns true if the computation of the given instruction is significantly +// more expensive than just writing all the values of the instructions' result +// array. 
Expensive operations should not be duplicated. +bool IsExpensive(const HloInstruction& instruction); + +// Returns true if fusing producer into consumer would cause producer to be +// duplicated. This is the case if producer has uses other than consumer. +bool FusionWouldDuplicate(HloInstruction* producer, HloInstruction* consumer); + +// HLO pass which performs instruction fusion. Instructions are fused +// "vertically", meaning producing instructions are fused into their consumers +// with the intent that the loops which compute their values will be fused in +// code generation. Derived classes define ShouldFuse method to select which +// instructions to fuse. +class InstructionFusion : public HloPass { + public: + explicit InstructionFusion(bool may_duplicate = true) + : HloPass("fusion"), may_duplicate_(may_duplicate) {} + ~InstructionFusion() override {} + + // Run instruction fusion on the given computation. Returns whether the + // computation was changed (instructions were fused). + StatusOr Run(HloModule* module) override; + + protected: + // Returns whether the given producer instruction should be fused into the + // given consumer instruction. producer is necessarily an operand of consumer. + // Derived classes should define this method to specify which instructions + // should be fused. `operand_index` is which operand of the consumer the + // producer is. + // + // Instructions are traversed in reverse post order (computation root to + // leaves). This method is called for each operand of the instruction (where + // the operand is 'producer' and the instruction is 'consumer') + // + // Subtypes can override this with target-specific heuristics. + virtual bool ShouldFuse(HloInstruction* consumer, int64 operand_index); + + // Chooses a fusion kind for `producer` and `consumer`. + // Default method chooses `kLoop`. + virtual HloInstruction::FusionKind ChooseKind(const HloInstruction* producer, + const HloInstruction* consumer); + + // Current HloComputation instance the loop fuser is traversing. + HloComputation* computation_; + + private: + HloInstruction* Fuse(HloInstruction* producer, HloInstruction* consumer); + + // Returns whether we may duplicate an instruction if we want to fuse it. + bool may_duplicate_; + + TF_DISALLOW_COPY_AND_ASSIGN(InstructionFusion); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_INSTRUCTION_FUSION_H_ diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc new file mode 100644 index 0000000000..2e3742ed75 --- /dev/null +++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc @@ -0,0 +1,140 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
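Since ShouldFuse is the customization point named in the header, a backend would typically subclass the pass; the class below is a hypothetical example whose name and extra restriction are illustrative only, layered on top of the default cost conditions:

#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/instruction_fusion.h"

namespace xla {

// Hypothetical backend-specific fusion pass that additionally refuses to fuse
// parameter producers, then defers to the base-class heuristics.
class NoParameterFusion : public InstructionFusion {
 public:
  NoParameterFusion() : InstructionFusion(/*may_duplicate=*/true) {}

  bool ShouldFuse(HloInstruction* consumer, int64 operand_index) override {
    const HloInstruction* producer = consumer->operand(operand_index);
    if (producer->opcode() == HloOpcode::kParameter) {
      return false;
    }
    return InstructionFusion::ShouldFuse(consumer, operand_index);
  }
};

}  // namespace xla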
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/instruction_fusion.h" + +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" + +namespace xla { + +using InstructionFusionTest = HloTestBase; + +TEST_F(InstructionFusionTest, + CostlyProducerAndOperandElementReusingConsumerNotFused) { + HloComputation::Builder builder(TestName()); + HloInstruction* const0 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(5))); + HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( + ShapeUtil::MakeShape(S32, {}), HloOpcode::kExp, const0)); + HloInstruction* broadcast2 = + builder.AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShape(S32, {1}), exp1, {0})); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + EXPECT_EQ(broadcast2, computation->root_instruction()); + EXPECT_TRUE( + InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie()); + EXPECT_EQ(broadcast2, computation->root_instruction()); +} + +TEST_F(InstructionFusionTest, + NonCostlyProducerAndOperandElementReusingConsumerFused) { + HloComputation::Builder builder(TestName()); + HloInstruction* const0 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(5))); + HloInstruction* negate1 = builder.AddInstruction(HloInstruction::CreateUnary( + ShapeUtil::MakeShape(S32, {}), HloOpcode::kNegate, const0)); + HloInstruction* broadcast2 = + builder.AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShape(S32, {1}), negate1, {0})); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + EXPECT_EQ(broadcast2, computation->root_instruction()); + EXPECT_TRUE( + InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie()); + EXPECT_EQ(HloOpcode::kFusion, computation->root_instruction()->opcode()); +} + +TEST_F(InstructionFusionTest, + CostlyProducerAndNonOperandElementReusingConsumerFused_Reshape) { + HloComputation::Builder builder(TestName()); + HloInstruction* const0 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(5))); + HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( + ShapeUtil::MakeShape(S32, {}), HloOpcode::kExp, const0)); + HloInstruction* reshape2 = builder.AddInstruction( + HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), exp1)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + EXPECT_EQ(reshape2, computation->root_instruction()); + EXPECT_TRUE( + InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie()); + EXPECT_EQ(HloOpcode::kFusion, computation->root_instruction()->opcode()); +} + +TEST_F(InstructionFusionTest, + CostlyProducerAndNonOperandElementReusingConsumerFused_Transpose) { + HloComputation::Builder builder(TestName()); + HloInstruction* const0 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(5))); + HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( + ShapeUtil::MakeShape(S32, {}), HloOpcode::kExp, const0)); + HloInstruction* transpose2 = builder.AddInstruction( + HloInstruction::CreateTranspose(ShapeUtil::MakeShape(S32, {}), exp1, {})); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + EXPECT_EQ(transpose2, computation->root_instruction()); + 
EXPECT_TRUE( + InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie()); + EXPECT_EQ(HloOpcode::kFusion, computation->root_instruction()->opcode()); +} + +TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfParameterUnfused) { + HloComputation::Builder builder(TestName()); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(S32, {}), "0")); + auto reshape1 = builder.AddInstruction( + HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {1, 1}), param0)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + EXPECT_EQ(reshape1, computation->root_instruction()); + EXPECT_FALSE( + InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie()); +} + +TEST_F(InstructionFusionTest, PotentialBitcastSimpleReshapeOfParameterUnfused) { + HloComputation::Builder builder(TestName()); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(S32, {}), "0")); + auto reshape1 = builder.AddInstruction( + HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {1, 1}), param0)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + EXPECT_EQ(reshape1, computation->root_instruction()); + EXPECT_FALSE( + InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie()); +} + +TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfParameterUnfused) { + HloComputation::Builder builder(TestName()); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(S32, {}), "0")); + auto transpose1 = builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(S32, {}), param0, {})); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + EXPECT_EQ(transpose1, computation->root_instruction()); + EXPECT_FALSE( + InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie()); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc new file mode 100644 index 0000000000..a8f2a6b89c --- /dev/null +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -0,0 +1,1334 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
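The tests above all construct the pass with may_duplicate=true; the sketch below shows the stricter configuration, under which FusionWouldDuplicate alone is enough to reject a candidate. The wrapper name is illustrative:

#include "tensorflow/compiler/xla/service/instruction_fusion.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace xla {

// With may_duplicate=false, a producer that has users other than the consumer
// is never fused, even if it would be cheap to recompute.
StatusOr<bool> RunFusionWithoutDuplication(HloModule* module) {
  return InstructionFusion(/*may_duplicate=*/false).Run(module);
}

}  // namespace xla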
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/layout_assignment.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/computation_layout.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/logical_buffer.h" +#include "tensorflow/compiler/xla/shape_layout.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace xla { + +std::ostream& operator<<(std::ostream& out, + const LayoutConstraint& constraint) { + out << constraint.ToString(); + return out; +} + +BufferLayoutConstraint::BufferLayoutConstraint(const Layout& layout, + const LogicalBuffer& buffer) + : layout_(layout), buffer_(&buffer) { + CHECK(LayoutUtil::ValidateLayoutForShape(layout, buffer.shape()).ok()); +} + +string BufferLayoutConstraint::ToString() const { + return tensorflow::strings::Printf("BufferLayoutConstraint %s: %s", + buffer_->ToString().c_str(), + LayoutUtil::HumanString(layout_).c_str()); +} + +OperandLayoutConstraint::OperandLayoutConstraint( + const ShapeLayout& shape_layout, const HloInstruction* instruction, + int64 operand_no) + : shape_layout_(shape_layout), + instruction_(instruction), + operand_no_(operand_no) { + CHECK(shape_layout_.LayoutIsSet()); + CHECK(ShapeUtil::Compatible(shape_layout.shape(), + instruction->operand(operand_no)->shape())); +} + +string OperandLayoutConstraint::ToString() const { + return tensorflow::strings::Printf( + "OperandLayoutConstraint %s, operand %lld: %s", + instruction_->name().c_str(), operand_no_, + shape_layout_.ToString().c_str()); +} + +string ResultLayoutConstraint::ToString() const { + return tensorflow::strings::Printf("ResultLayoutConstraint: %s", + shape_layout_.ToString().c_str()); +} + +LayoutConstraints::LayoutConstraints( + const TuplePointsToAnalysis& points_to_analysis, + const HloComputation* computation) + : points_to_analysis_(points_to_analysis), computation_(computation) { + // Gather all array-shaped logical buffers into unconstrained_buffer_ids. + for (auto& buffer : points_to_analysis_.logical_buffers()) { + if (buffer->IsArray()) { + unconstrained_buffer_ids_.insert(buffer->id()); + } + } +} + +bool LayoutConstraints::OperandBufferForwarded( + const HloInstruction* instruction, int64 operand_no) const { + // The operand is potentially forwarded if the intersection of points-to sets + // of the operand and the instruction is non-empty. 
+ auto output_buffers = + points_to_analysis_.GetPointsToSet(instruction).CreateFlattenedSet(); + auto operand_buffers = + points_to_analysis_.GetPointsToSet(instruction->operand(operand_no)) + .CreateFlattenedSet(); + std::vector intersection; + std::set_intersection(output_buffers.begin(), output_buffers.end(), + operand_buffers.begin(), operand_buffers.end(), + std::back_inserter(intersection)); + return !intersection.empty(); +} + +Status LayoutConstraints::SetBufferLayout(const Layout& layout, + const LogicalBuffer& buffer) { + VLOG(3) << "SetBufferLayout : " << buffer << " : " + << LayoutUtil::HumanString(layout); + + TF_RETURN_IF_ERROR(points_to_analysis_.VerifyBuffer(buffer)); + if (!buffer.IsArray()) { + return FailedPrecondition( + "Layout of buffer %s cannot be constrained because buffer is not " + "array-shaped, has shape: %s", + buffer.ToString().c_str(), + ShapeUtil::HumanString(buffer.shape()).c_str()); + } + TF_RETURN_IF_ERROR( + LayoutUtil::ValidateLayoutForShape(layout, buffer.shape())); + + const Layout* curr_layout = BufferLayout(buffer); + if (curr_layout != nullptr) { + if (!LayoutUtil::Equal(*curr_layout, layout)) { + return FailedPrecondition( + "Buffer %s already has the layout constraint %s, cannot add " + "incompatible constraint %s", + buffer.ToString().c_str(), + LayoutUtil::HumanString(*curr_layout).c_str(), + LayoutUtil::HumanString(layout).c_str()); + } + // New constraint matches existing constraint. Nothing to do. + return Status::OK(); + } + + auto new_constraint_it = buffer_constraints_.insert( + {&buffer, BufferLayoutConstraint(layout, buffer)}); + added_constraints_.push_back(&new_constraint_it.first->second); + + // Remove buffer from the set of unconstrained buffers. + TF_RET_CHECK(unconstrained_buffer_ids_.count(buffer.id()) == 1); + unconstrained_buffer_ids_.erase(buffer.id()); + + return Status::OK(); +} + +Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout, + const HloInstruction* instruction, + int64 operand_no) { + VLOG(3) << "SetOperandLayout : " << instruction->name() << ", operand " + << operand_no << " : " + << ShapeUtil::HumanStringWithLayout(shape_with_layout); + + const ShapeLayout* curr_shape_layout = OperandLayout(instruction, operand_no); + if (curr_shape_layout != nullptr) { + if (!curr_shape_layout->MatchesLayoutInShape(shape_with_layout)) { + return FailedPrecondition( + "Operand %lld of instruction %s already has a layout constraint " + "%s, cannot add incompatible constraint %s", + operand_no, instruction->name().c_str(), + curr_shape_layout->ToString().c_str(), + ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str()); + } + // New constraint matches existing constraint. Nothing to do. + return Status::OK(); + } + + // If any buffers in the operand occur in the output of the instruction, then + // return an error. This case is not handled because such a constraint changes + // layouts beyond this immediate use and is complicated to handle. 
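The compatibility test used by SetBufferLayout above reduces to layout equality; a small sketch, with the helper name chosen here only for illustration:

#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {

// A second constraint on an already-constrained buffer is accepted only if it
// requests exactly the same layout; anything else is a FailedPrecondition.
bool ConstraintsCompatible(const Layout& existing, const Layout& requested) {
  return LayoutUtil::Equal(existing, requested);
}

}  // namespace xla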
+ if (OperandBufferForwarded(instruction, operand_no)) { + return FailedPrecondition( + "Cannot constraint layout of operand %lld of instruction %s " + "because instruction forwards operand's LogicalBuffer(s)", + operand_no, instruction->name().c_str()); + } + + auto key = std::make_pair(instruction, operand_no); + auto new_constraint_it = operand_constraints_.insert( + {key, OperandLayoutConstraint(ShapeLayout(shape_with_layout), instruction, + operand_no)}); + added_constraints_.push_back(&new_constraint_it.first->second); + + return Status::OK(); +} + +Status LayoutConstraints::SetArrayOperandLayout( + const Layout& layout, const HloInstruction* instruction, int64 operand_no) { + const HloInstruction* operand = instruction->operand(operand_no); + TF_RET_CHECK(ShapeUtil::IsArray(operand->shape())); + Shape shape(operand->shape()); + *shape.mutable_layout() = layout; + TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutInShape(shape)); + return SetOperandLayout(shape, instruction, operand_no); +} + +Status LayoutConstraints::SetResultLayout(const Shape& shape_with_layout) { + VLOG(3) << "SetResultLayout : " + << ShapeUtil::HumanStringWithLayout(shape_with_layout); + + const ShapeLayout* curr_shape_layout = ResultLayout(); + if (curr_shape_layout != nullptr) { + if (!curr_shape_layout->MatchesLayoutInShape(shape_with_layout)) { + return FailedPrecondition( + "Result of computation %s already has the layout constraint %s, " + "cannot add incompatible constraint %s", + computation_->name().c_str(), curr_shape_layout->ToString().c_str(), + ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str()); + } + // New constraint matches existing constraint. Nothing to do. + return Status::OK(); + } + + result_constraint_.reset( + new ResultLayoutConstraint(ShapeLayout(shape_with_layout))); + added_constraints_.push_back(result_constraint_.get()); + + return Status::OK(); +} + +Status LayoutConstraints::SetInstructionLayout( + const Shape& shape_with_layout, const HloInstruction* instruction) { + VLOG(3) << "SetInstructionLayout : " << instruction->name() << ", " + << ShapeUtil::HumanStringWithLayout(shape_with_layout); + + if (!ShapeUtil::Compatible(shape_with_layout, instruction->shape())) { + return FailedPrecondition( + "Instruction %s of shape %s cannot be assigned incompatible layout %s", + instruction->name().c_str(), + ShapeUtil::HumanString(instruction->shape()).c_str(), + ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str()); + } + + // Create a BufferLayoutConstraint for each array shape in the output of the + // instruction. + return ShapeUtil::ForEachSubshape( + shape_with_layout, + [this, instruction](const Shape& subshape, + const ShapeIndex& index) -> Status { + // The precondition for this method is that the instruction defines all + // buffers in its output. + auto buffers = + points_to_analysis_.GetPointsToSet(instruction).element(index); + CHECK_EQ(1, buffers.size()); + CHECK_EQ(buffers[0]->instruction(), instruction); + + if (ShapeUtil::IsArray(subshape)) { + return SetBufferLayout(subshape.layout(), *buffers[0]); + } else { + return Status::OK(); + } + }); +} + +const Layout* LayoutConstraints::BufferLayout( + const LogicalBuffer& buffer) const { + auto it = buffer_constraints_.find(&buffer); + return it == buffer_constraints_.end() ? 
nullptr : &it->second.layout(); +} + +const ShapeLayout* LayoutConstraints::OperandLayout( + const HloInstruction* instruction, int64 operand_no) const { + auto it = operand_constraints_.find(std::make_pair(instruction, operand_no)); + return it == operand_constraints_.end() ? nullptr + : &it->second.shape_layout(); +} + +const ShapeLayout* LayoutConstraints::ResultLayout() const { + return result_constraint_ ? &result_constraint_->shape_layout() : nullptr; +} + +string LayoutConstraints::ToString() const { + string output; + tensorflow::strings::StrAppend(&output, "LayoutConstraints for computation ", + computation_->name(), ":\n"); + for (auto* instruction : computation_->MakeInstructionPostOrder()) { + tensorflow::strings::StrAppend(&output, " ", instruction->ToShortString(), + "\n"); + for (int64 i = 0; i < instruction->operand_count(); ++i) { + if (OperandLayout(instruction, i) != nullptr) { + tensorflow::strings::StrAppend( + &output, " operand (", i, "): ", + OperandLayout(instruction, i)->ToString(), "\n"); + } + } + for (const LogicalBuffer* buffer : + points_to_analysis_.GetBuffersDefinedByInstruction(instruction)) { + if (BufferLayout(*buffer) != nullptr) { + tensorflow::strings::StrAppend( + &output, " ", buffer->ToString(), " : ", + LayoutUtil::HumanString(*BufferLayout(*buffer)), "\n"); + } + } + } + + if (ResultLayout() != nullptr) { + tensorflow::strings::StrAppend(&output, " => ", ResultLayout()->ToString(), + "\n"); + } + return output; +} + +Status LayoutAssignment::AddMandatoryConstraints( + const ComputationLayout& computation_layout, HloComputation* computation, + LayoutConstraints* constraints) { + VLOG(3) << "Adding mandatory layout constraints to computation " + << computation->name(); + + // Constrain layouts of instructions which define values with pre-existing + // layouts. + for (auto& instruction : computation->instructions()) { + Shape const* shape_with_layout = nullptr; + if (instruction->opcode() == HloOpcode::kConstant) { + // Constant layouts must match the layout of their literal. + shape_with_layout = &instruction->literal().shape(); + } else if (instruction->opcode() == HloOpcode::kInfeed) { + // Infeed layouts must match the layout of the original inserted + // instruction. + // TODO(b/31425034): Change infeeds to be more like parameters, with + // shapes in the ComputationLayout. + shape_with_layout = &instruction->shape(); + } else if (instruction->opcode() == HloOpcode::kParameter) { + // Parameter layouts must match the respective layout in + // ComputationLayout. + shape_with_layout = + &computation_layout.parameter_layout(instruction->parameter_number()) + .shape(); + } + if (shape_with_layout != nullptr) { + TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(*shape_with_layout, + instruction.get())); + } + } + + // Constrain layouts of instructions which call computations which have + // already been assigned layouts. Instructions which call computations in a + // parallel element-wise context (eg, map or reduce) do not need layout + // constraints because they operate on scalars. + for (auto& instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kCall) { + // kCall instruction operands and output must match the ComputationLayout + // of the called computation. 
+ const ComputationLayout& called_computation_layout = + FindOrDie(computation_layouts_, instruction->to_apply()); + TF_RETURN_IF_ERROR(constraints->SetInstructionLayout( + called_computation_layout.result_layout().shape(), + instruction.get())); + TF_RET_CHECK(instruction->operand_count() == + called_computation_layout.parameter_count()); + for (int64 i = 0; i < instruction->operand_count(); ++i) { + TF_RETURN_IF_ERROR(constraints->SetOperandLayout( + called_computation_layout.parameter_layout(i).shape(), + instruction.get(), i)); + } + } else if (instruction->opcode() == HloOpcode::kWhile) { + // Layout of input and output of kWhile instruction must be equal and must + // match both input and output of body computation. Also, the input of + // condition computation must match kWhile layout. + HloComputation* body = instruction->while_body(); + HloComputation* condition = instruction->while_condition(); + const HloInstruction* init = instruction->operand(0); + const ComputationLayout& body_layout = + FindOrDie(computation_layouts_, body); + const ComputationLayout& condition_layout = + FindOrDie(computation_layouts_, condition); + + // Check a few invariants irrespective of layout. + CHECK_EQ(1, instruction->operand_count()); + CHECK_EQ(1, body->num_parameters()); + CHECK_EQ(1, condition->num_parameters()); + DCHECK(ShapeUtil::Compatible(body_layout.result_shape(), + body_layout.parameter_shape(0))); + DCHECK(ShapeUtil::Compatible(body_layout.result_shape(), + condition_layout.parameter_shape(0))); + DCHECK(ShapeUtil::Compatible(body_layout.result_shape(), init->shape())); + + // Return error if earlier layout assignment of the embedded computations + // has produced conflicting layouts. + if (!ShapeUtil::Equal(body_layout.result_shape(), + body_layout.parameter_shape(0))) { + return InternalError( + "Parameter and result of body computation %s of while instruction " + "%s have different layouts: %s vs %s", + body->name().c_str(), instruction->name().c_str(), + ShapeUtil::HumanString(body_layout.result_shape()).c_str(), + ShapeUtil::HumanString(body_layout.parameter_shape(0)).c_str()); + } + if (!ShapeUtil::Equal(body->root_instruction()->shape(), + condition->parameter_instruction(0)->shape())) { + return InternalError( + "Parameter of condition computation %s of while instruction " + "%s does not match body computation %s result: %s vs %s", + condition->name().c_str(), instruction->name().c_str(), + body->name().c_str(), + ShapeUtil::HumanString(condition_layout.parameter_shape(0)).c_str(), + ShapeUtil::HumanString(body_layout.result_shape()).c_str()); + } + + // Constrain the output and the operand of the while instruction to match + // the computations. + TF_RETURN_IF_ERROR(constraints->SetInstructionLayout( + body_layout.result_shape(), instruction.get())); + TF_RETURN_IF_ERROR(constraints->SetOperandLayout( + body_layout.result_shape(), instruction.get(), 0)); + } else if (instruction->opcode() == HloOpcode::kCustomCall) { + // Add constraints for kCustomCall instruction operands and instructions. + // For now we only support row major layouts for all inputs and outputs. 
+ auto row_major_shape = [](const Shape& old_shape) { + Shape new_shape(old_shape); + std::vector dimension_order(new_shape.dimensions_size()); + std::iota(dimension_order.rbegin(), dimension_order.rend(), 0); + *new_shape.mutable_layout() = LayoutUtil::MakeLayout(dimension_order); + return new_shape; + }; + + Shape result_shape(row_major_shape(instruction->shape())); + TF_RETURN_IF_ERROR( + constraints->SetInstructionLayout(result_shape, instruction.get())); + for (int64 i = 0; i < instruction->operand_count(); ++i) { + const Shape& operand_shape = instruction->operand(i)->shape(); + // Opaque operands don't get a layout constraint. + if (ShapeUtil::IsOpaque(operand_shape)) { + continue; + } + + Shape row_major_operand_shape(row_major_shape(operand_shape)); + TF_RETURN_IF_ERROR(constraints->SetOperandLayout( + row_major_operand_shape, instruction.get(), i)); + } + } + } + + // Finally set the result layout to match ComputationLayout. + return constraints->SetResultLayout( + computation_layout.result_layout().shape()); +} + +namespace { + +// The operands of a call must match the layouts of parameters in the +// ComputationLayout, and the call instruction itself must match the result +// layout in the ComputationLayout. +Status CheckCallLayout(HloInstruction* call, + const ComputationLayout& computation_layout) { + HloComputation* computation = call->to_apply(); + TF_RET_CHECK(computation->num_parameters() == call->operand_count()); + for (int64 i = 0; i < computation->num_parameters(); ++i) { + TF_RET_CHECK(computation_layout.parameter_layout(i).MatchesLayoutInShape( + call->operand(i)->shape())); + } + TF_RET_CHECK( + computation_layout.result_layout().MatchesLayoutInShape(call->shape())); + return Status::OK(); +} + +// Custom calls have fixed input and output layouts. +Status CheckCustomCallLayout(HloInstruction* custom_call) { + for (const HloInstruction* operand : custom_call->operands()) { + TF_RET_CHECK( + LayoutUtil::IsMonotonicWithDim0Major(operand->shape().layout())); + } + TF_RET_CHECK( + LayoutUtil::IsMonotonicWithDim0Major(custom_call->shape().layout())); + return Status::OK(); +} + +// For a while instruction, all the following layouts must be the same: +// (1) init operand +// (2) condition computation parameter +// (3) body computation parameter +// (4) body computation result +// (5) while instruction result +Status CheckWhileLayout(HloInstruction* while_inst, + const ComputationLayout& condition_computation_layout, + const ComputationLayout& body_computation_layout) { + auto init_shape = while_inst->operand(0)->shape(); + TF_RET_CHECK( + condition_computation_layout.parameter_layout(0).MatchesLayoutInShape( + init_shape)); + TF_RET_CHECK(body_computation_layout.parameter_layout(0).MatchesLayoutInShape( + init_shape)); + TF_RET_CHECK( + body_computation_layout.result_layout().MatchesLayoutInShape(init_shape)); + TF_RET_CHECK( + LayoutUtil::LayoutsInShapesEqual(init_shape, while_inst->shape())); + return Status::OK(); +} + +// Fusion parameters must match the layout of the fusion instructions operands, +// and the root of the fusion expression must match the layout of the fusion +// instruction. 
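The row_major_shape lambda above relies on filling the dimension order back to front; a self-contained sketch of that trick for a rank-3 shape:

#include <cassert>
#include <numeric>
#include <vector>

int main() {
  // std::iota over the reverse iterators writes 0 into the last slot, so the
  // resulting minor-to-major order is descending: {2, 1, 0}, i.e. row major.
  std::vector<int> dimension_order(3);
  std::iota(dimension_order.rbegin(), dimension_order.rend(), 0);
  assert(dimension_order == (std::vector<int>{2, 1, 0}));
  return 0;
}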
+Status CheckFusionLayout(HloInstruction* fusion) { + TF_RET_CHECK(HloOpcode::kFusion == fusion->opcode()); + + TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual( + fusion->shape(), fusion->fused_expression_root()->shape())); + for (int64 i = 0; i < fusion->operand_count(); ++i) { + TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual( + fusion->fused_parameter(i)->shape(), fusion->operand(i)->shape())); + } + return Status::OK(); +} + +// The layout of a parameter must match the respective layout in the +// computation's ComputationLayout. +Status CheckParameterLayout(HloInstruction* parameter, + const ComputationLayout& computation_layout) { + const ShapeLayout& parameter_layout = + computation_layout.parameter_layout(parameter->parameter_number()); + if (!parameter_layout.MatchesLayoutInShape(parameter->shape())) { + return InternalError( + "parameter instruction %s does not match layout of computation " + "shape: %s", + parameter->ToString().c_str(), parameter_layout.ToString().c_str()); + } + return Status::OK(); +} + +// The layout of a constant instruction must match the layout of its literal. +Status CheckConstantLayout(HloInstruction* constant) { + if (!LayoutUtil::LayoutsInShapesEqual(constant->literal().shape(), + constant->shape())) { + return InternalError( + "constant instruction %s does not match the layout of its literal %s", + constant->ToString().c_str(), + ShapeUtil::HumanStringWithLayout(constant->literal().shape()).c_str()); + } + return Status::OK(); +} + +// Check that all layouts in the module have been set and satisfy all necessary +// conditions. +Status CheckLayouts( + HloModule* module, + const std::map& computation_layouts) { + TF_ASSIGN_OR_RETURN(auto points_to_analysis, + TuplePointsToAnalysis::Run(module)); + for (auto& computation : module->computations()) { + for (auto& instruction : computation->instructions()) { + // Verify every instruction has a layout and the layout is valid for the + // shape. + TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape())); + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape())); + + // Use points-to analysis to verify that every subshape element in the + // output of the instruction matches the layout of the logical buffer + // which could be the source of the subshape value. + const PointsToSet& points_to_set = + points_to_analysis->GetPointsToSet(instruction.get()); + TF_RETURN_IF_ERROR(points_to_set.ForEachElement( + [&instruction]( + ShapeIndex index, bool is_leaf, + const std::vector& buffers) -> Status { + if (is_leaf) { + const Shape& instruction_subshape = + ShapeUtil::GetSubshape(instruction->shape(), index); + for (const LogicalBuffer* buffer : buffers) { + if (!ShapeUtil::Equal(instruction_subshape, buffer->shape())) { + return InternalError( + "Layout of instruction %s at index {%s} does not match " + "source LogicalBuffer %s: %s vs %s", + instruction->name().c_str(), + tensorflow::str_util::Join(index, ",").c_str(), + buffer->ToString().c_str(), + ShapeUtil::HumanStringWithLayout(instruction_subshape) + .c_str(), + ShapeUtil::HumanStringWithLayout(buffer->shape()) + .c_str()); + } + } + } + return Status::OK(); + })); + + // Verify instructions that have special layout constraints. 
+ switch (instruction->opcode()) { + case HloOpcode::kCall: + TF_RETURN_IF_ERROR(CheckCallLayout( + instruction.get(), + FindOrDie(computation_layouts, instruction->to_apply()))); + break; + case HloOpcode::kCustomCall: + TF_RETURN_IF_ERROR(CheckCustomCallLayout(instruction.get())); + break; + case HloOpcode::kFusion: + TF_RETURN_IF_ERROR(CheckFusionLayout(instruction.get())); + break; + case HloOpcode::kParameter: + TF_RETURN_IF_ERROR(CheckParameterLayout( + instruction.get(), + FindOrDie(computation_layouts, instruction->parent()))); + break; + case HloOpcode::kConstant: + TF_RETURN_IF_ERROR(CheckConstantLayout(instruction.get())); + break; + case HloOpcode::kWhile: + TF_RETURN_IF_ERROR(CheckWhileLayout( + instruction.get(), + FindOrDie(computation_layouts, instruction->while_condition()), + FindOrDie(computation_layouts, instruction->while_body()))); + break; + default: + break; + } + } + } + + // Finally verify the result layout matches the layout of the entry + // computation root. + TF_RET_CHECK(ShapeUtil::Equal( + module->entry_computation()->root_instruction()->shape(), + FindOrDie(computation_layouts, module->entry_computation()) + .result_layout() + .shape())); + + return Status::OK(); +} + +} // namespace + +LayoutAssignment::LayoutAssignment(ComputationLayout* entry_computation_layout) + : HloPass("layout-assignment"), + entry_computation_layout_(entry_computation_layout) { + VLOG(1) << "entry computation layout given to layout assignment: " + << entry_computation_layout_->ToString(); + // Layouts of all parameter instructions must be set. + for (const ShapeLayout& parameter_layout : + entry_computation_layout_->parameter_layouts()) { + CHECK(parameter_layout.LayoutIsSet()); + } + // If the result layout is not set, then choose the default. + // TODO(b/29118294): Choose a better layout in this case. + if (!entry_computation_layout_->result_layout().LayoutIsSet()) { + entry_computation_layout_->mutable_result_layout()->SetToDefaultLayout(); + } +} + +namespace { + +// Given a pemutation of `{0, 1, ..., n}` `indices`, returns a permutation of +// `{0, 1, ..., n - to_delete.size() + to_insert.size()}` by deleting the +// indices `to_delete` wherever in `indices` they are, and inserting the indices +// `to_insert` arbitrarily at the back. 
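A worked example of the permutation helper documented above, written against std::vector<int> so it stands alone (the real helper operates on a proto RepeatedField of int64); the behavior mirrors the implementation that follows:

#include <algorithm>
#include <cassert>
#include <functional>
#include <vector>

// Standalone analogue of DeleteAndInsertIndices: removes the deleted indices,
// renumbers the survivors, and appends the inserted indices at the back.
std::vector<int> DeleteAndInsert(std::vector<int> to_delete,
                                 std::vector<int> to_insert,
                                 std::vector<int> indices) {
  std::sort(to_delete.begin(), to_delete.end(), std::greater<int>());
  std::sort(to_insert.begin(), to_insert.end());
  for (int index : to_delete) {
    auto i = indices.begin();
    while (i != indices.end()) {
      if (*i == index) {
        i = indices.erase(i);
      } else {
        if (*i > index) --(*i);
        ++i;
      }
    }
  }
  for (int index : to_insert) {
    for (int& value : indices) {
      if (value >= index) ++value;
    }
    indices.push_back(index);
  }
  return indices;
}

int main() {
  // Deleting index 1 from the permutation {0, 2, 1} leaves {0, 1}.
  assert((DeleteAndInsert({1}, {}, {0, 2, 1}) == std::vector<int>{0, 1}));
  // Inserting index 0 into {1, 0} shifts existing entries up: {2, 1, 0}.
  assert((DeleteAndInsert({}, {0}, {1, 0}) == std::vector<int>{2, 1, 0}));
  return 0;
}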
+tensorflow::protobuf::RepeatedField +DeleteAndInsertIndices( + std::vector to_delete, std::vector to_insert, + tensorflow::protobuf::RepeatedField indices) { + std::sort(to_delete.begin(), to_delete.end(), std::greater()); + std::sort(to_insert.begin(), to_insert.end(), std::less()); + for (auto index : to_delete) { + auto i = indices.begin(); + while (i != indices.end()) { + if (*i == index) { + i = indices.erase(i); + } else { + if (*i > index) { + (*i)--; + } + ++i; + } + } + } + for (auto index : to_insert) { + for (auto i = indices.begin(); i != indices.end(); ++i) { + if (*i >= index) { + (*i)++; + } + } + indices.Add(index); + } + return indices; +} + +} // namespace + +std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( + const Layout& output_layout, const HloInstruction* instruction, + int64 operand_no) { + const HloInstruction* operand = instruction->operand(operand_no); + + CHECK(ShapeUtil::IsArray(instruction->shape()) && + ShapeUtil::IsArray(operand->shape())); + + if (instruction->IsElementwiseOnOperand(operand_no) && + !ShapeUtil::IsScalar(operand->shape()) && + ShapeUtil::Rank(operand->shape()) == + ShapeUtil::Rank(instruction->shape())) { + // Assign operands the same layout as the instruction, so that + // 1) the elementwise operation can reuse its operand's buffer, and + // 2) the input and output elements can reuse the same linear index. + // + // TODO(jingyue): Other operations, such as kSlice and kConcat, can benefit + // from assigning the same layout to input and output. + return MakeUnique(output_layout); + } + + if (instruction->opcode() == HloOpcode::kReshape) { + // Pick the operand layout that makes the reshape a bitcast. If the reshape + // only inserts or deletes degenerate dimensions, we can easily compute the + // desired layout by accordingly inserting and deleting the elements in the + // minor-to-major list. + bool merely_inserts_or_deletes_1_sized_dims; + std::vector inserted_indices, deleted_indices; + std::tie(merely_inserts_or_deletes_1_sized_dims, deleted_indices, + inserted_indices) = + instruction->ReshapeMerelyInsertsOrDeletes1SizedDimensions(); + if (merely_inserts_or_deletes_1_sized_dims) { + Layout operand_layout = LayoutUtil::MakeLayout( + AsInt64Slice(DeleteAndInsertIndices(inserted_indices, deleted_indices, + output_layout.minor_to_major()))); + TF_CHECK_OK( + LayoutUtil::ValidateLayoutForShape(operand_layout, operand->shape())); + return MakeUnique(operand_layout); + } + } + + if (instruction->opcode() == HloOpcode::kTranspose) { + // Pick the operand layout that makes the transpose a bitcast. + std::vector perm = + ComposePermutations(instruction->dimensions(), + AsInt64Slice(output_layout.minor_to_major())); + Layout operand_layout = LayoutUtil::MakeLayout(perm); + TF_CHECK_OK( + LayoutUtil::ValidateLayoutForShape(operand_layout, operand->shape())); + return MakeUnique(operand_layout); + } + + return nullptr; +} + +std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( + const Layout& operand_layout, const HloInstruction* user, + int64 operand_no) { + const HloInstruction* operand = user->operand(operand_no); + + CHECK(ShapeUtil::IsArray(user->shape()) && + ShapeUtil::IsArray(operand->shape())); + + if (user->IsElementwiseOnOperand(operand_no) && + !ShapeUtil::IsScalar(operand->shape()) && + ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(user->shape())) { + // Assign users the same layout as the operand. 
+    return MakeUnique<Layout>(operand_layout);
+  }
+
+  if (user->opcode() == HloOpcode::kReshape) {
+    // Pick the user layout that makes the reshape a bitcast.
+    bool merely_inserts_or_deletes_1_sized_dims;
+    std::vector<int64> inserted_indices, deleted_indices;
+    std::tie(merely_inserts_or_deletes_1_sized_dims, deleted_indices,
+             inserted_indices) =
+        user->ReshapeMerelyInsertsOrDeletes1SizedDimensions();
+    if (merely_inserts_or_deletes_1_sized_dims) {
+      Layout user_layout = LayoutUtil::MakeLayout(AsInt64Slice(
+          DeleteAndInsertIndices(deleted_indices, inserted_indices,
+                                 operand_layout.minor_to_major())));
+      TF_CHECK_OK(
+          LayoutUtil::ValidateLayoutForShape(user_layout, user->shape()));
+      return MakeUnique<Layout>(user_layout);
+    }
+  }
+
+  if (user->opcode() == HloOpcode::kTranspose) {
+    // Pick the user layout that makes the transpose a bitcast.
+    // To become a bitcast, the layouts need to satisfy
+    //   collapsing_order * output_layout = input_layout
+    // so output_layout = inverse(collapsing_order) * input_layout
+    std::vector<int64> perm =
+        Permute(InversePermutation(user->dimensions()),
+                AsInt64Slice(operand_layout.minor_to_major()));
+    Layout user_layout = LayoutUtil::MakeLayout(perm);
+    TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(user_layout, user->shape()));
+    return MakeUnique<Layout>(user_layout);
+  }
+
+  return nullptr;
+}
+
+Status LayoutAssignment::PropagateConstraints(LayoutConstraints* constraints) {
+  // Gathers all initial constraints in a worklist and propagates them in
+  // depth-first order. DFS order seems to be better than BFS because a
+  // constraint is propagated as far as possible before propagating unrelated
+  // constraints, which makes it less likely that conflicting constraints will
+  // be propagated to instructions. However, we should experiment with other
+  // orders too.
+  std::deque<const LayoutConstraint*> worklist;
+
+  // Lambda for moving newly added constraints to the worklist.
+  auto add_new_constraints_to_worklist = [constraints, &worklist]() {
+    // Add constraints to the front of the deque for DFS ordering.
+    for (auto* constraint : constraints->ConsumeAddedConstraints()) {
+      worklist.push_front(constraint);
+    }
+  };
+  add_new_constraints_to_worklist();
+
+  while (!worklist.empty()) {
+    const LayoutConstraint* layout_constraint = worklist.front();
+    worklist.pop_front();
+    VLOG(2) << "Propagating " << layout_constraint->ToString()
+            << " to its neighbors.";
+    if (auto* buffer_constraint =
+            dynamic_cast<const BufferLayoutConstraint*>(layout_constraint)) {
+      TF_RETURN_IF_ERROR(
+          PropagateBufferConstraint(*buffer_constraint, constraints));
+    } else if (auto* operand_constraint =
+                   dynamic_cast<const OperandLayoutConstraint*>(
+                       layout_constraint)) {
+      TF_RETURN_IF_ERROR(
+          PropagateOperandConstraint(*operand_constraint, constraints));
+    } else if (auto* result_constraint =
+                   dynamic_cast<const ResultLayoutConstraint*>(
+                       layout_constraint)) {
+      TF_RETURN_IF_ERROR(
+          PropagateResultConstraint(*result_constraint, constraints));
+    } else {
+      LOG(FATAL) << "Invalid constraint type: " << *layout_constraint;
+    }
+
+    add_new_constraints_to_worklist();
+  }
+  return Status::OK();
+}
+
+namespace {
+
+// Returns a vector containing all array-shaped uses (instruction and operand
+// number) of the given logical buffer or its aliases.
+std::vector> GetArrayUsesOfBuffer( + const LogicalBuffer& buffer, + const TuplePointsToAnalysis& points_to_analysis) { + CHECK(buffer.IsArray()); + std::vector> uses; + for (const auto& buffer_alias : points_to_analysis.GetBufferAliases(buffer)) { + if (!ShapeUtil::IsArray(buffer_alias.instruction()->shape())) { + continue; + } + // This alias must be the top-level (index == {}) of the instruction's + // result because the instruction produces an array. + CHECK(buffer_alias.index().empty()); + + // Add all uses of the instruction's output. + for (const HloInstruction* user : buffer_alias.instruction()->users()) { + for (int64 operand_no : + user->OperandIndices(buffer_alias.instruction())) { + uses.emplace_back(user, operand_no); + } + } + } + return uses; +} + +} // namespace + +Status LayoutAssignment::PropagateUseConstraintToDefs( + const ShapeLayout& shape_layout, const HloInstruction* instruction, + LayoutConstraints* constraints) { + // Try to set all logical buffers which may be sources of the given operand to + // match the given layout. + const PointsToSet& points_to_set = + constraints->points_to_analysis().GetPointsToSet(instruction); + return points_to_set.ForEachElement( + [this, &shape_layout, constraints]( + const ShapeIndex& index, bool is_leaf, + const std::vector& buffers) -> Status { + if (is_leaf) { + for (const LogicalBuffer* buffer : buffers) { + if (constraints->BufferLayout(*buffer) == nullptr && + ShapeUtil::IsArray(buffer->shape())) { + TF_RETURN_IF_ERROR(constraints->SetBufferLayout( + ShapeUtil::GetSubshape(shape_layout.shape(), index).layout(), + *buffer)); + } + } + } + return Status::OK(); + }); +} + +Status LayoutAssignment::PropagateOperandConstraint( + const OperandLayoutConstraint& operand_constraint, + LayoutConstraints* constraints) { + // Try to set the layout of the logical buffers in the given operand to match + // the constrained layout. This avoids copies. + TF_RETURN_IF_ERROR( + PropagateUseConstraintToDefs(operand_constraint.shape_layout(), + operand_constraint.operand(), constraints)); + + // For array-shaped operands and user instructions try to pick a minimum cost + // layout. For example, if the operand of a elementwise instruction is + // constained to a certain layout we want the output of the instruction to + // have the same layout. + const HloInstruction* operand = operand_constraint.operand(); + const HloInstruction* user = operand_constraint.instruction(); + if (!ShapeUtil::IsArray(operand->shape()) || + !ShapeUtil::IsArray(user->shape())) { + return Status::OK(); + } + + // Only try to choose a low cost layout if the instruction 'user' defines its + // output (ie, doesn't forward a buffer from elsewhere). + if (constraints->OperandBufferForwarded(user, + operand_constraint.operand_no())) { + return Status::OK(); + } + TF_ASSIGN_OR_RETURN( + const LogicalBuffer* buffer, + constraints->points_to_analysis().GetBufferDefinedAt(user, /*index=*/{})); + + if (constraints->BufferLayout(*buffer) == nullptr) { + std::unique_ptr layout = ChooseOutputLayoutFromOperandLayout( + operand_constraint.shape_layout().layout(), user, + operand_constraint.operand_no()); + if (layout != nullptr) { + TF_RETURN_IF_ERROR(constraints->SetBufferLayout(*layout, *buffer)); + } + } + return Status::OK(); +} + +Status LayoutAssignment::PropagateBufferConstraint( + const BufferLayoutConstraint& buffer_constraint, + LayoutConstraints* constraints) { + // Only propagate array layouts. 
+ const LogicalBuffer& buffer = buffer_constraint.buffer(); + if (!buffer.IsArray()) { + return Status::OK(); + } + + // If this buffer is the result of an array-shaped op (as opposed to an array + // element in a tuple) try to propagate the layout to its operands. + if (buffer.IsTopLevel()) { + const HloInstruction* instruction = buffer.instruction(); + // Propagate the def-constraint on an instruction to the use-constraints on + // its operands (use-def propagation). + for (int64 operand_no = 0; operand_no < instruction->operand_count(); + ++operand_no) { + if (constraints->OperandLayout(instruction, operand_no) == nullptr && + ShapeUtil::IsArray(instruction->operand(operand_no)->shape())) { + std::unique_ptr operand_layout = + ChooseOperandLayoutFromOutputLayout(buffer_constraint.layout(), + instruction, operand_no); + if (operand_layout != nullptr) { + TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout( + *operand_layout, instruction, operand_no)); + } + } + } + } + + // Propagate the layout to all array uses of the logical buffer. This skips + // uses of the buffer where the buffer is the element of a tuple. + for (const auto& user_operand_no : + GetArrayUsesOfBuffer(buffer, constraints->points_to_analysis())) { + const HloInstruction* user = user_operand_no.first; + int64 operand_no = user_operand_no.second; + // Only add an operand constraint if the user does not forward the buffer + // because this case is not handled is SetOperandLayout. + if (constraints->OperandLayout(user, operand_no) == nullptr && + !constraints->OperandBufferForwarded(user, operand_no)) { + TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout( + buffer_constraint.layout(), user, operand_no)); + } + } + + return Status::OK(); +} + +Status LayoutAssignment::PropagateResultConstraint( + const ResultLayoutConstraint& result_constraint, + LayoutConstraints* constraints) { + // Propagate the use constraint of the root instruction up to the logical + // buffers which make up the result. + return PropagateUseConstraintToDefs( + result_constraint.shape_layout(), + constraints->computation()->root_instruction(), constraints); +} + +namespace { + +// Infers the layout of the array at the given index in the given instruction's +// output using points-to analysis. Precondition: The given instruction must +// not produce this array value (that is, the array is forwarded from the +// instruction's operands). +StatusOr InferArrayLayout( + const TuplePointsToAnalysis& points_to_analysis, + HloInstruction* instruction, const ShapeIndex& index) { + // This function should only be called for array shapes which don't yet have + // layouts. + const Shape& subshape = ShapeUtil::GetSubshape(instruction->shape(), index); + TF_RET_CHECK(ShapeUtil::IsArray(subshape)); + TF_RET_CHECK(!subshape.has_layout()); + + // The instruction should not define the buffer at this index. + TF_RET_CHECK( + !points_to_analysis.InstructionDefinesBufferAtIndex(instruction, index)); + + const std::vector& source_buffers = + points_to_analysis.GetPointsToSet(instruction).element(index); + TF_RET_CHECK(!source_buffers.empty()); + + // Verify the layout is the same for every LogicalBuffer which this location + // ('instruction' and 'index') points to. + const Layout* first_buffer_layout = nullptr; + for (const LogicalBuffer* source_buffer : source_buffers) { + if (!source_buffer->shape().has_layout()) { + // This should not happen because we've assigned layouts to all + // instructions preceding this one. 
+ return InternalError("LogicalBuffer %s does not have a layout", + source_buffer->ToString().c_str()); + } + + if (first_buffer_layout == nullptr) { + first_buffer_layout = &source_buffer->shape().layout(); + } else if (!LayoutUtil::Equal(source_buffer->shape().layout(), + *first_buffer_layout)) { + // The points-to set is ambiguous for this index and the different source + // buffers have different layouts. This case is possible in valid XLA + // computations because we do not propagate BufferLayoutConstaints to all + // LogicalBuffers which may alias the constrained LogicalBuffer at some + // point in the computation. + return FailedPrecondition( + "Array at index {%s} in instruction %s aliases buffers %s " + "and %s which have different layouts", + tensorflow::str_util::Join(index, ",").c_str(), + instruction->name().c_str(), source_buffers[0]->ToString().c_str(), + source_buffer->ToString().c_str()); + } + } + + return *first_buffer_layout; +} + +// Creates and returns a copy of the given instruction with a different +// layout. Tuple-shaped instructions will be deep-copied, and the last Tuple +// instruction producing the copy is returned. +StatusOr CreateCopyWithNewLayout( + const Shape& shape_with_layout, HloInstruction* instruction) { + TF_RET_CHECK(LayoutUtil::HasLayout(shape_with_layout)); + DCHECK(ShapeUtil::Compatible(shape_with_layout, instruction->shape())); + + if (ShapeUtil::IsTuple(instruction->shape())) { + // Deep-copy tuples. + std::vector element_copies; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape()); + ++i) { + HloInstruction* gte = instruction->parent()->AddInstruction( + HloInstruction::CreateGetTupleElement( + ShapeUtil::GetSubshape(instruction->shape(), {i}), instruction, + i)); + + // Recurse to copy each elements. + TF_ASSIGN_OR_RETURN( + HloInstruction * element_copy, + CreateCopyWithNewLayout( + ShapeUtil::GetSubshape(shape_with_layout, {i}), gte)); + element_copies.push_back(element_copy); + } + // Gather element copies into a tuple with a new Tuple instruction. + HloInstruction* tuple_copy = instruction->parent()->AddInstruction( + HloInstruction::CreateTuple(element_copies)); + LayoutUtil::ClearLayout(tuple_copy->mutable_shape()); + TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( + shape_with_layout, tuple_copy->mutable_shape())); + return tuple_copy; + } else if (ShapeUtil::IsArray(instruction->shape())) { + HloInstruction* copy = + instruction->parent()->AddInstruction(HloInstruction::CreateUnary( + instruction->shape(), HloOpcode::kCopy, instruction)); + LayoutUtil::ClearLayout(copy->mutable_shape()); + TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( + shape_with_layout, copy->mutable_shape())); + + return copy; + } else { + return FailedPrecondition( + "Can only copy array and tuple shaped instructions"); + } +} + +// Creates a copy of the given operand if the operand's layout does not match +// the given layout. This copy replaces the use in the given instruction. Tuple +// operands will be deep-copied. +Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout, + HloInstruction* instruction, + int64 operand_no) { + HloInstruction* operand = instruction->mutable_operand(operand_no); + TF_RET_CHECK(operand_layout.LayoutIsSet()); + TF_RET_CHECK(LayoutUtil::HasLayout(operand->shape())); + + if (ShapeUtil::Equal(operand_layout.shape(), operand->shape())) { + // Operand layout already matches our constraint. Nothing to do. 
+ return Status::OK(); + } + + TF_ASSIGN_OR_RETURN(HloInstruction * operand_copy, + CreateCopyWithNewLayout(operand_layout.shape(), operand)); + + instruction->ReplaceOperandWith(operand_no, operand_copy); + return Status::OK(); +} + +// For fusion instructions, set the layout of each fused parameter instruction +// to match the layout of its corresponding fusion instruction operand. Also, +// set the layout of the fused root to match the layout of the fusion +// instruction itself. +// Fused GetTupleElement requires a layout so that TBAA metadata for the tuple +// element array pointer load can be added. +Status SetFusionLayouts(HloInstruction* fusion) { + TF_RET_CHECK(fusion->opcode() == HloOpcode::kFusion); + for (auto& fused_instruction : fusion->fused_instructions()) { + if (fused_instruction->opcode() == HloOpcode::kParameter) { + const HloInstruction* fusion_operand = + fusion->operand(fused_instruction->parameter_number()); + DCHECK(ShapeUtil::Compatible(fusion_operand->shape(), + fused_instruction->shape())); + TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( + fusion_operand->shape(), fused_instruction->mutable_shape())); + } else if (fused_instruction.get() == fusion->fused_expression_root()) { + // The layout of the root of the fused expression must match the fusion + // instruction layout. + DCHECK( + ShapeUtil::Compatible(fusion->shape(), fused_instruction->shape())); + TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( + fusion->shape(), fused_instruction->mutable_shape())); + } else if (fused_instruction->opcode() != HloOpcode::kConstant && + fused_instruction->opcode() != HloOpcode::kGetTupleElement && + fused_instruction->opcode() != HloOpcode::kInfeed) { + // Internal fused instructions with the exception of constants + // and infeed need no layout. + LayoutUtil::ClearLayout(fused_instruction->mutable_shape()); + } + } + + return Status::OK(); +} + +} // namespace + +Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints, + HloComputation* computation) { + VLOG(2) << "Assigning layouts to computation: " << computation->name(); + XLA_VLOG_LINES(2, computation->ToString()); + XLA_VLOG_LINES(2, constraints.ToString()); + + for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) { + LayoutUtil::ClearLayout(instruction->mutable_shape()); + + // Create a copy of an operand if the operand instruction's layout does not + // match the use constraint (OperandLayoutConstraint). + for (int64 operand_no = 0; operand_no < instruction->operand_count(); + ++operand_no) { + const ShapeLayout* operand_layout = + constraints.OperandLayout(instruction, operand_no); + if (operand_layout != nullptr) { + TF_RETURN_IF_ERROR(CopyOperandIfLayoutsDiffer(*operand_layout, + instruction, operand_no)); + } + } + + // Set the layouts of the array shapes this instruction defines as + // indicated by the respective BufferLayoutConstraints. Any array shapes + // in the output of the instruction which are not defined by the instruction + // (eg, array elements in a Tuple instruction) will be assigned below via + // inference. 
+ for (const LogicalBuffer* buffer : + constraints.points_to_analysis().GetBuffersDefinedByInstruction( + instruction)) { + if (!ShapeUtil::IsArray(buffer->shape())) { + continue; + } + + TF_RET_CHECK(buffer->instruction() == instruction); + Shape* buffer_subshape = ShapeUtil::GetMutableSubshape( + instruction->mutable_shape(), buffer->index()); + const Layout* buffer_layout = constraints.BufferLayout(*buffer); + TF_RET_CHECK(buffer_layout != nullptr); + *buffer_subshape->mutable_layout() = *buffer_layout; + } + + // Any remaining layouts in the output of the instruction must be + // inferrable using points-to analysis. + TF_RETURN_IF_ERROR(ShapeUtil::ForEachMutableSubshape( + instruction->mutable_shape(), + [instruction, &constraints](Shape* subshape, const ShapeIndex& index) { + if (subshape->has_layout() || !ShapeUtil::IsArray(*subshape)) { + return Status::OK(); + } + // Set Layout of subshape to match layout of LogicalBuffer which + // produces it. + TF_ASSIGN_OR_RETURN(*subshape->mutable_layout(), + InferArrayLayout(constraints.points_to_analysis(), + instruction, index)); + return Status::OK(); + })); + + // Fusion instructions require some layouts to be set on fused instructions + // inside the fusion instruction. + if (instruction->opcode() == HloOpcode::kFusion) { + TF_RETURN_IF_ERROR(SetFusionLayouts(instruction)); + } + + // Verify all layouts in the shape have been set. + TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape())); + } + + // Copy the root instrucion's result if the it does not match the result + // layout constraint + if (constraints.ResultLayout() != nullptr && + !constraints.ResultLayout()->MatchesLayoutInShape( + computation->root_instruction()->shape())) { + TF_ASSIGN_OR_RETURN( + HloInstruction * new_root, + CreateCopyWithNewLayout(constraints.ResultLayout()->shape(), + computation->root_instruction())); + computation->set_root_instruction(new_root); + } + + return Status::OK(); +} + +Status LayoutAssignment::RunOnComputation( + const ComputationLayout& computation_layout, HloComputation* computation) { + DCHECK(computation_layout.LayoutIsSet()); + InsertOrDie(&computation_layouts_, computation, computation_layout); + VLOG(2) << "LayoutAssignment::RunOnComputation(" << computation->name() + << ")"; + VLOG(2) << " ComputationLayout = " << computation_layout.ToString(); + + TF_ASSIGN_OR_RETURN(auto points_to_analysis, + TuplePointsToAnalysis::Run(computation->parent())); + + // Construct LayoutConstaints with all layout constraints of the computation. + LayoutConstraints constraints(*points_to_analysis, computation); + + // Add constraints required for correctness on all backends (eg, entry + // parameter layout constraints). + TF_RETURN_IF_ERROR( + AddMandatoryConstraints(computation_layout, computation, &constraints)); + + // Add any backend-specific constraints. + TF_RETURN_IF_ERROR(AddBackendConstraints(&constraints)); + + // Propagates layouts from an HLO to its neighbors. + TF_RETURN_IF_ERROR(PropagateConstraints(&constraints)); + + // While any unconstrained buffers remain, pick an arbitrary buffer, give it a + // layout and propagate the change. + while (!constraints.unconstrained_buffer_ids().empty()) { + int unconstrained_count = constraints.unconstrained_buffer_ids().size(); + + // Arbitrarily pick the first unconstrained buffer and give it the default + // layout. By construction unconstrained_buffers() has a stable sort based + // on LogicalBuffer::Id. 
+ const LogicalBuffer& buffer = points_to_analysis->GetBuffer( + *constraints.unconstrained_buffer_ids().begin()); + TF_RETURN_IF_ERROR(constraints.SetBufferLayout( + LayoutUtil::GetDefaultLayoutForShape(buffer.shape()), buffer)); + + TF_RETURN_IF_ERROR(PropagateConstraints(&constraints)); + + // To verify progress has been made, check that the number of unconstrained + // buffers has been reduced. + CHECK_LT(constraints.unconstrained_buffer_ids().size(), + unconstrained_count); + } + + // All logical buffers should have constraints at this point. All that + // remains is assign the constraints to the buffers and infer layouts for + // aliased buffers. + return AssignLayouts(constraints, computation); +} + +StatusOr LayoutAssignment::Run(HloModule* module) { + VLOG(2) << "Running layout assignment on module " << module->name(); + XLA_VLOG_LINES(3, module->ToString()); + if (VLOG_IS_ON(10)) { + hlo_graph_dumper::DumpGraph(*module->entry_computation(), + "before layout assignment", + /*show_addresses=*/false, + /*show_layouts=*/true); + } + + // Assign layouts to computations in an order such that a callee computation + // is handled before its caller computation. This ensures that the layout of + // all callers of a computation will agree. + for (auto* computation : module->MakeComputationPostOrder()) { + if (computation == module->entry_computation()) { + TF_RETURN_IF_ERROR(RunOnComputation(*entry_computation_layout_, + module->entry_computation())); + } else { + ComputationLayout computation_layout(computation->ComputeProgramShape()); + // Setting all embedded computations to the default layout is potentially + // suboptimal. + computation_layout.SetToDefaultLayout(); + TF_RETURN_IF_ERROR(RunOnComputation(computation_layout, computation)); + } + } + + TF_RETURN_IF_ERROR(CheckLayouts(module, computation_layouts_)); + + VLOG(3) << "After layout assignment:"; + XLA_VLOG_LINES(3, module->ToString()); + if (VLOG_IS_ON(10)) { + hlo_graph_dumper::DumpGraph(*module->entry_computation(), + "after layout assignment", + /*show_addresses=*/false, + /*show_layouts=*/true); + } + + // All layouts are reset then reassigned by this pass. + return true; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h new file mode 100644 index 0000000000..3c67a95412 --- /dev/null +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -0,0 +1,302 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LAYOUT_ASSIGNMENT_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_LAYOUT_ASSIGNMENT_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/service/computation_layout.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass.h" +#include "tensorflow/compiler/xla/service/logical_buffer.h" +#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" +#include "tensorflow/compiler/xla/shape_layout.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// Abstract base class for layout constraints. These constraint objects are +// gathered together in LayoutConstraints object. +class LayoutConstraint { + public: + LayoutConstraint() = default; + virtual ~LayoutConstraint() = default; + + virtual string ToString() const = 0; +}; + +std::ostream& operator<<(std::ostream& out, const LayoutConstraint& constraint); + +// Layout constraint on a single LogicalBuffer. This constrains the layout of an +// array produced by a particular instruction. +class BufferLayoutConstraint : public LayoutConstraint { + public: + BufferLayoutConstraint(const Layout& layout, const LogicalBuffer& buffer); + + const LogicalBuffer& buffer() const { return *buffer_; } + const Layout& layout() const { return layout_; } + + string ToString() const override; + + private: + const Layout layout_; + const LogicalBuffer* buffer_; +}; + +// Constraint on the layout of the operand of an instruction. The constrained +// shape can be arbitrarily shaped (array or tuple). This is a constraint on the +// use of a shaped value and is not a hard constraint on the instruction(s) +// which define the value as copies may be inserted between the definition and +// use. +class OperandLayoutConstraint : public LayoutConstraint { + public: + OperandLayoutConstraint(const ShapeLayout& shape_layout, + const HloInstruction* instruction, int64 operand_no); + + const ShapeLayout& shape_layout() const { return shape_layout_; } + const HloInstruction* instruction() const { return instruction_; } + const int64 operand_no() const { return operand_no_; } + const HloInstruction* operand() const { + return instruction_->operand(operand_no_); + } + + string ToString() const override; + + private: + const ShapeLayout shape_layout_; + const HloInstruction* instruction_; + int64 operand_no_; +}; + +// Constraint on the layout of the result of the entry computation. +class ResultLayoutConstraint : public LayoutConstraint { + public: + explicit ResultLayoutConstraint(const ShapeLayout& shape_layout) + : shape_layout_(shape_layout) {} + + const ShapeLayout& shape_layout() const { return shape_layout_; } + string ToString() const override; + + private: + const ShapeLayout shape_layout_; +}; + +// Class encapsulating the layout constraints of the values in a HLO +// computation. 
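+//
+// Example use from a backend-specific pass (editor's sketch, not part of the
+// original header; `MyBackendLayoutAssignment` and its AddBackendConstraints
+// override are hypothetical, and it is assumed that
+// HloComputation::instructions() yields the computation's instructions):
+//
+//   Status MyBackendLayoutAssignment::AddBackendConstraints(
+//       LayoutConstraints* constraints) {
+//     for (auto& instruction : constraints->computation()->instructions()) {
+//       // Request a row-major layout for the first operand of every Dot.
+//       if (instruction->opcode() == HloOpcode::kDot) {
+//         TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
+//             LayoutUtil::MakeLayout({1, 0}), instruction.get(),
+//             /*operand_no=*/0));
+//       }
+//     }
+//     return Status::OK();
+//   }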
+class LayoutConstraints {
+ public:
+  LayoutConstraints(const TuplePointsToAnalysis& points_to_analysis,
+                    const HloComputation* computation);
+  ~LayoutConstraints() = default;
+
+  const HloComputation* computation() const { return computation_; }
+  const TuplePointsToAnalysis& points_to_analysis() const {
+    return points_to_analysis_;
+  }
+
+  // Return a vector containing the constraints which have been added to the
+  // LayoutConstraints object since the construction of the object or since the
+  // last time ConsumeAddedConstraints() has been called. This is used to
+  // identify newly added constraints when propagating layouts.
+  std::vector<const LayoutConstraint*> ConsumeAddedConstraints() {
+    std::vector<const LayoutConstraint*> ret_vec(
+        std::move(added_constraints_));
+    added_constraints_.clear();
+    return ret_vec;
+  }
+  void ClearAddedConstraints() { added_constraints_.clear(); }
+
+  // Returns the layout of a LogicalBuffer, the layout of the operand of the
+  // instruction, or the layout of the result of the computation, respectively,
+  // if it has been constrained. Otherwise return nullptr.
+  const Layout* BufferLayout(const LogicalBuffer& buffer) const;
+  const ShapeLayout* OperandLayout(const HloInstruction* instruction,
+                                   int64 operand_no) const;
+  const ShapeLayout* ResultLayout() const;
+
+  // Add a constraint on the layout of a LogicalBuffer, the layout of the
+  // operand of the instruction, or the layout of the result of the
+  // computation, respectively.
+  Status SetBufferLayout(const Layout& layout, const LogicalBuffer& buffer);
+  Status SetOperandLayout(const Shape& shape_with_layout,
+                          const HloInstruction* instruction, int64 operand_no);
+  Status SetResultLayout(const Shape& shape_with_layout);
+
+  // Convenience wrapper around SetOperandLayout for setting the layout of an
+  // operand using a Layout object. The operand must be array-shaped.
+  Status SetArrayOperandLayout(const Layout& layout,
+                               const HloInstruction* instruction,
+                               int64 operand_no);
+
+  // Convenience wrapper around SetBufferLayout. Sets the layouts of all buffers
+  // created by the instruction to the layouts in the given shape. The
+  // instruction must define every logical buffer in its output.
+  Status SetInstructionLayout(const Shape& shape_with_layout,
+                              const HloInstruction* instruction);
+
+  // Returns true if any buffer in the given operand is forwarded to the output
+  // of the given instruction. For example, the Tuple instruction forwards the
+  // buffers of its operands and would return true for each of its operands.
+  bool OperandBufferForwarded(const HloInstruction* instruction,
+                              int64 operand_no) const;
+
+  // Returns the set of logical buffers (by LogicalBuffer::Id) which do not
+  // yet have a layout constraint.
+  const std::set<LogicalBuffer::Id>& unconstrained_buffer_ids() const {
+    return unconstrained_buffer_ids_;
+  }
+
+  string ToString() const;
+
+ private:
+  // The set of BufferLayoutConstraints applied to the computation.
+  std::unordered_map<const LogicalBuffer*, BufferLayoutConstraint>
+      buffer_constraints_;
+
+  // The set of OperandLayoutConstraints applied to the computation.
+  using OperandConstraintKey = std::pair<const HloInstruction*, int64>;
+  std::map<OperandConstraintKey, OperandLayoutConstraint> operand_constraints_;
+
+  // The result constraint for the computation (can be null).
+  std::unique_ptr<ResultLayoutConstraint> result_constraint_;
+
+  // A vector which holds constraints as they are added. Can be cleared with
+  // ClearAddedConstraints.
+  std::vector<const LayoutConstraint*> added_constraints_;
+
+  // Points-to analysis for the module. Used to propagate constraints through
+  // the HLO graph.
+ const TuplePointsToAnalysis& points_to_analysis_; + + // Array-shaped buffers which have not yet been constrained. + std::set unconstrained_buffer_ids_; + + const HloComputation* computation_; +}; + +// HLO pass which assigns layouts to all instructions in the HLO module while +// satisfying all necessary invariants and minimizing cost. +class LayoutAssignment : public HloPass { + public: + // entry_computation_layout is modified to populate a layout for the result in + // the case that no particular layout is requested. + explicit LayoutAssignment(ComputationLayout* entry_computation_layout); + ~LayoutAssignment() override {} + + // Assign layouts to the given module. Returns whether the module was changed + // (any layouts were changed). + StatusOr Run(HloModule* module) override; + + protected: + // These methods, invoked by PropagateConstraints, propagate a layout + // constraint to its neighbors (i.e. operands and users) in order to minimize + // the cost of the instructions being constrainted on. New constraints are + // added to the given constraint set. + // + // Backends can override these methods with backend-specific propagation + // rules. + virtual Status PropagateBufferConstraint( + const BufferLayoutConstraint& layout_constraint, + LayoutConstraints* constraints); + virtual Status PropagateOperandConstraint( + const OperandLayoutConstraint& layout_constraint, + LayoutConstraints* constraints); + virtual Status PropagateResultConstraint( + const ResultLayoutConstraint& layout_constraint, + LayoutConstraints* constraints); + + private: + // Adds constraints which must be satisfied for correctness on all + // backends. Called once prior to propagating constraints. + Status AddMandatoryConstraints(const ComputationLayout& computation_layout, + HloComputation* computation, + LayoutConstraints* constraints); + + // This method can be overridden to add backend-specific constraints to the + // layout of the instructions of a computation. This method is called after + // all mandatory constraints have been added via AddMandatoryConstraints + // and before propagating constraints. + virtual Status AddBackendConstraints(LayoutConstraints* constraints) { + return Status::OK(); + } + + // Construct contraints and assign layouts to all instructions in the + // computation satisfying the given ComputationLayout. Layouts constraints are + // added, then propagated until all LogicalBuffers in the computation are + // constrained. + Status RunOnComputation(const ComputationLayout& computation_layout, + HloComputation* computation); + + // Assign layouts to the instructions of a computation which satisfy the given + // layout constraints. Copies may be added to satisfy the constraints. The + // given LayoutConstraints must have layout constraints every logical buffer + // in the computation. + Status AssignLayouts(const LayoutConstraints& constraints, + HloComputation* computation); + + // Propagates layout constraints from a set of initial constraints in order to + // minimize the local cost of the computation. This propagation is *not* + // required for correctness. + Status PropagateConstraints(LayoutConstraints* constraints); + + // Propagates a layout constraint on the use of the result of the given + // instruction to the definitions of the LogicalBuffers which make up the + // result. 
+ Status PropagateUseConstraintToDefs(const ShapeLayout& shape_layout, + const HloInstruction* instruction, + LayoutConstraints* constraints); + + // Chooses a layout of operand `operand_no` of `instruction` that minimizes + // the cost of `instruction`. `output_layout` is the layout of `instruction`. + // Returns null if it can't decide the best layout. + // Precondition: `instruction` and the operand are array-shaped. + std::unique_ptr ChooseOperandLayoutFromOutputLayout( + const Layout& output_layout, const HloInstruction* instruction, + int64 operand_no); + // Given the layout of `user`'s `operand_no`-th operand, chooses a layout of + // `user` that minimizes its cost on that operand. Returns null if it can't + // decide the best layout. + // Precondition: `user` and the operand are array-shaped. + std::unique_ptr ChooseOutputLayoutFromOperandLayout( + const Layout& operand_layout, const HloInstruction* user, + int64 operand_no); + + ComputationLayout* entry_computation_layout_; + + // Map containing the layouts of all computations assigned so + // far. Computations are handled in a topological sort where computations are + // handled before their caller instructions so the layouts of caller + // instructions can be set to match the computation. + std::map computation_layouts_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LAYOUT_ASSIGNMENT_H_ diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc new file mode 100644 index 0000000000..6361907b0e --- /dev/null +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -0,0 +1,486 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/layout_assignment.h" + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/algebraic_simplifier.h" +#include "tensorflow/compiler/xla/service/computation_layout.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_layout.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace xla { +namespace { + +class LayoutAssignmentTest : public HloTestBase { + protected: + void AssignLayouts(HloModule* module, + ComputationLayout* entry_computation_layout) { + LayoutAssignment layout_assignment(entry_computation_layout); + EXPECT_IS_OK(layout_assignment.Run(module).status()); + } +}; + +TEST_F(LayoutAssignmentTest, ComputationLayout) { + // Verify the layouts of the root and parameter instructions of a computation + // match the ComputationLayout for two different layouts. + std::vector> minor_to_majors = {{0, 1}, {1, 0}}; + for (auto& minor_to_major : minor_to_majors) { + auto builder = HloComputation::Builder(TestName()); + Shape ashape = ShapeUtil::MakeShape(F32, {42, 12}); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, ashape, "param0")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, ashape, "param1")); + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, param0, param1)); + HloModule module(TestName()); + HloComputation* computation = module.AddEntryComputation(builder.Build()); + + Layout layout = LayoutUtil::MakeLayout(minor_to_major); + Shape shape(ashape); + *shape.mutable_layout() = layout; + const ShapeLayout shape_layout(shape); + + ComputationLayout computation_layout(computation->ComputeProgramShape()); + *computation_layout.mutable_parameter_layout(0) = shape_layout; + *computation_layout.mutable_parameter_layout(1) = shape_layout; + *computation_layout.mutable_result_layout() = shape_layout; + AssignLayouts(&module, &computation_layout); + EXPECT_TRUE(LayoutUtil::Equal(layout, param0->shape().layout())); + EXPECT_TRUE(LayoutUtil::Equal(layout, param1->shape().layout())); + EXPECT_TRUE(LayoutUtil::Equal(layout, add->shape().layout())); + } +} + +TEST_F(LayoutAssignmentTest, ComputationLayoutMixedLayout) { + // Verify the layouts of the root and parameter instructions of a computation + // match the ComputationLayout which has mixed layout. 
+ auto builder = HloComputation::Builder(TestName()); + Shape ashape = ShapeUtil::MakeShape(F32, {42, 12}); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, ashape, "param0")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, ashape, "param1")); + builder.AddInstruction( + HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, param0, param1)); + HloModule module(TestName()); + HloComputation* computation = module.AddEntryComputation(builder.Build()); + + Layout col_major_layout = LayoutUtil::MakeLayout({1, 0}); + Shape col_major_shape(ashape); + *col_major_shape.mutable_layout() = col_major_layout; + const ShapeLayout col_major(col_major_shape); + + Layout row_major_layout = LayoutUtil::MakeLayout({0, 1}); + Shape row_major_shape(ashape); + *row_major_shape.mutable_layout() = row_major_layout; + const ShapeLayout row_major(row_major_shape); + + ComputationLayout computation_layout(computation->ComputeProgramShape()); + *computation_layout.mutable_parameter_layout(0) = col_major; + *computation_layout.mutable_parameter_layout(1) = row_major; + *computation_layout.mutable_result_layout() = col_major; + + AssignLayouts(&module, &computation_layout); + EXPECT_TRUE(LayoutUtil::Equal(col_major_layout, param0->shape().layout())); + EXPECT_TRUE(LayoutUtil::Equal(row_major_layout, param1->shape().layout())); + EXPECT_TRUE(LayoutUtil::Equal( + col_major_layout, computation->root_instruction()->shape().layout())); +} + +TEST_F(LayoutAssignmentTest, FusionInstruction) { + // Verify that the layout of the fused parameters in a fusion instruction + // match that of the fusion operands. Other fused instructions should have no + // layout. + std::vector> minor_to_majors = {{0, 1}, {1, 0}}; + for (auto& minor_to_major : minor_to_majors) { + auto builder = HloComputation::Builder(TestName()); + auto constant_literal1 = test_utils::CreateR2LiteralWithLayout( + {{1.0, 2.0}, {3.0, 4.0}}, minor_to_major); + auto constant_literal2 = test_utils::CreateR2LiteralWithLayout( + {{5.0, 6.0}, {7.0, 8.0}}, minor_to_major); + Shape ashape = constant_literal1->shape(); + + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(std::move(constant_literal1))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(std::move(constant_literal2))); + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + ashape, HloOpcode::kAdd, constant1, constant2)); + auto negate1 = builder.AddInstruction( + HloInstruction::CreateUnary(ashape, HloOpcode::kNegate, add)); + auto negate2 = builder.AddInstruction( + HloInstruction::CreateUnary(ashape, HloOpcode::kNegate, negate1)); + + HloModule module(TestName()); + HloComputation* computation = module.AddEntryComputation(builder.Build()); + + auto fusion = computation->CreateFusionInstruction( + {negate2, negate1, add}, HloInstruction::FusionKind::kLoop); + + Layout layout = LayoutUtil::MakeLayout(minor_to_major); + Shape shape(ashape); + *shape.mutable_layout() = layout; + const ShapeLayout shape_layout(shape); + + ComputationLayout computation_layout(computation->ComputeProgramShape()); + *computation_layout.mutable_result_layout() = shape_layout; + + AssignLayouts(&module, &computation_layout); + + EXPECT_TRUE(LayoutUtil::Equal( + layout, fusion->fused_parameter(0)->shape().layout())); + EXPECT_TRUE(LayoutUtil::Equal( + layout, fusion->fused_parameter(1)->shape().layout())); + EXPECT_TRUE(LayoutUtil::Equal( + layout, fusion->fused_expression_root()->shape().layout())); + + // Inner 
fused node should not have layout. + EXPECT_FALSE(LayoutUtil::HasLayout( + fusion->fused_expression_root()->operand(0)->shape())); + } +} + +TEST_F(LayoutAssignmentTest, TupleLayout) { + // Verify the layouts of a tuple are assigned properly (the element layouts + // match their source). + auto builder = HloComputation::Builder(TestName()); + auto constant0 = builder.AddInstruction(HloInstruction::CreateConstant( + test_utils::CreateR2LiteralWithLayout({{1.0, 2.0}, {3.0, 4.0}}, + {0, 1}))); + auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant( + test_utils::CreateR2LiteralWithLayout({{1.0, 2.0}, {3.0, 4.0}}, + {1, 0}))); + auto tuple = builder.AddInstruction( + HloInstruction::CreateTuple({constant0, constant1})); + + // To avoid having to construct a tuple layout in the ComputationLayout below, + // make the result of the instruction be an array. + auto get_element0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(constant0->shape(), tuple, 0)); + auto negate = builder.AddInstruction(HloInstruction::CreateUnary( + constant0->shape(), HloOpcode::kNegate, get_element0)); + + HloModule module(TestName()); + module.AddEntryComputation(builder.Build()); + + ComputationLayout computation_layout( + module.entry_computation()->ComputeProgramShape()); + + AssignLayouts(&module, &computation_layout); + + EXPECT_FALSE( + LayoutUtil::LayoutsInShapesEqual(constant0->shape(), constant1->shape())); + + EXPECT_TRUE(LayoutUtil::HasLayout(tuple->shape())); + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual( + negate->shape(), computation_layout.result_layout().shape())); + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual( + ShapeUtil::GetTupleElementShape(tuple->shape(), 1), constant1->shape())); +} + +TEST_F(LayoutAssignmentTest, TupleSelect) { + // Verify layouts of a select with tuple operands is assigned properly. + auto builder = HloComputation::Builder(TestName()); + auto constant0 = builder.AddInstruction(HloInstruction::CreateConstant( + test_utils::CreateR2LiteralWithLayout({{1.0, 2.0}, {3.0, 4.0}}, + {0, 1}))); + auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant( + test_utils::CreateR2LiteralWithLayout({{1.0, 2.0}, {3.0, 4.0}}, + {1, 0}))); + auto tuple0 = builder.AddInstruction( + HloInstruction::CreateTuple({constant0, constant1})); + auto tuple1 = builder.AddInstruction( + HloInstruction::CreateTuple({constant0, constant1})); + + auto pred = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); + + auto select = builder.AddInstruction(HloInstruction::CreateTernary( + tuple0->shape(), HloOpcode::kSelect, pred, tuple0, tuple1)); + + HloModule module(TestName()); + module.AddEntryComputation(builder.Build()); + + ComputationLayout computation_layout( + module.entry_computation()->ComputeProgramShape()); + Shape result_shape = + ShapeUtil::MakeTupleShape({constant0->shape(), constant1->shape()}); + TF_CHECK_OK(computation_layout.mutable_result_layout()->CopyLayoutFromShape( + result_shape)); + + AssignLayouts(&module, &computation_layout); + + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(result_shape, select->shape())); +} + +TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) { + // Construct following computation which has conflicting layouts for two + // elements of a tuple which share the same source logicalb buffer: + // + // %constant = Constant(...) 
+ // %inner_tuple = Tuple(%constant) + // %nested_tuple = Tuple(%inner_tuple, %inner_tuple) + // + // Result layout col-major for the first element and row-major for the + // second. This results in the conflict where the element of the inner_tuple + // needs to be both col and row major. This is resolved by deep-copying the + // tuple and assigning the layouts of the copied arrays as needed. + auto builder = HloComputation::Builder(TestName()); + auto constant = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + auto inner_tuple = + builder.AddInstruction(HloInstruction::CreateTuple({constant})); + auto nested_tuple = builder.AddInstruction( + HloInstruction::CreateTuple({inner_tuple, inner_tuple})); + + HloModule module(TestName()); + module.AddEntryComputation(builder.Build()); + + ComputationLayout computation_layout( + module.entry_computation()->ComputeProgramShape()); + Shape result_shape = nested_tuple->shape(); + *ShapeUtil::GetMutableSubshape(&result_shape, /*index=*/{0, 0}) = + ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}); + *ShapeUtil::GetMutableSubshape(&result_shape, /*index=*/{1, 0}) = + ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1}); + TF_CHECK_OK(computation_layout.mutable_result_layout()->CopyLayoutFromShape( + result_shape)); + + LayoutAssignment layout_assignment(&computation_layout); + AssignLayouts(&module, &computation_layout); + + // Layout assignment should have deep copied the result of the computation to + // address the layout conflict. This results in several Tuple() and + // GetTupleElement() instructions. Running algebraic simplification should + // clean up the code to something like: + // + // %constant = Constant(...) layout={1,0} + // %tuple.0 = Tuple(%constant) layout=({1,0}) + // %copy = Copy(%constant) layout={0,1} # layout transposed + // %tuple.1 = Tuple(%copy) layout=({0,1}) + // %tuple.2 = Tuple(%tuple.0, %tuple.1) layout=(({1,0}), ({0,1})) + // + EXPECT_TRUE( + AlgebraicSimplifier(/*is_layout_sensitive=*/true, + [](const Shape&, const Shape&) { return false; }) + .Run(&module) + .ValueOrDie()); + HloInstruction* root = module.entry_computation()->root_instruction(); + // Verify layout of the root and the root's operands. + EXPECT_TRUE(ShapeUtil::Equal(result_shape, root->shape())); + EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::GetSubshape(result_shape, {0}), + root->operand(0)->shape())); + EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::GetSubshape(result_shape, {1}), + root->operand(1)->shape())); + + // Verify some of the structure of the HLO graph. 
+ EXPECT_EQ(constant, root->operand(0)->operand(0)); + EXPECT_EQ(HloOpcode::kCopy, root->operand(1)->operand(0)->opcode()); + EXPECT_EQ(HloOpcode::kConstant, + root->operand(1)->operand(0)->operand(0)->opcode()); +} + +TEST_F(LayoutAssignmentTest, ElementwiseAndReshape) { + // param -> log -> reshape -> tanh + auto builder = HloComputation::Builder(TestName()); + Shape ashape = ShapeUtil::MakeShape(F32, {1, 2, 3, 1}); + Shape bshape = ShapeUtil::MakeShape(F32, {2, 1, 3}); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, ashape, "param")); + auto log = builder.AddInstruction( + HloInstruction::CreateUnary(ashape, HloOpcode::kLog, param)); + auto reshape = + builder.AddInstruction(HloInstruction::CreateReshape(bshape, log)); + auto tanh = builder.AddInstruction( + HloInstruction::CreateUnary(bshape, HloOpcode::kTanh, reshape)); + + HloModule module(TestName()); + HloComputation* computation = module.AddEntryComputation(builder.Build(tanh)); + + Shape ashape_with_layout(ashape); + Shape bshape_with_layout(bshape); + *ashape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 1, 2, 3}); + *bshape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 1, 2}); + + ComputationLayout computation_layout(computation->ComputeProgramShape()); + *computation_layout.mutable_parameter_layout(0) = + ShapeLayout(ashape_with_layout); + *computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout); + AssignLayouts(&module, &computation_layout); + + auto log_minor_to_major = + AsInt64Slice(log->shape().layout().minor_to_major()); + EXPECT_LT(PositionInContainer(log_minor_to_major, 1), + PositionInContainer(log_minor_to_major, 2)); + + auto reshape_minor_to_major = + AsInt64Slice(reshape->shape().layout().minor_to_major()); + EXPECT_LT(PositionInContainer(reshape_minor_to_major, 0), + PositionInContainer(reshape_minor_to_major, 2)); +} + +// Test whether LayoutAssignment assigns layouts to elementwise operations to +// keep linear indices valid across them, and to transpositions to make them +// bitcasts. 
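+//
+// Worked example for the transpose case below (editor's note, not part of the
+// original test): the operand is F32[42,12] with layout {1,0}, so dimension 1
+// (size 12) is physically minor. The transpose maps operand dimension 1 to
+// output dimension 0, so the output F32[12,42] keeps the same element order in
+// memory only if its layout is {0,1}; with that layout the transpose is a
+// bitcast, which is exactly what the test expects.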
+TEST_F(LayoutAssignmentTest, ElementwiseAndTranspose) { + // param -> log -> transpose -> tanh + auto builder = HloComputation::Builder(TestName()); + Shape ashape = ShapeUtil::MakeShape(F32, {42, 12}); + Shape bshape = ShapeUtil::MakeShape(F32, {12, 42}); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, ashape, "param")); + auto log = builder.AddInstruction( + HloInstruction::CreateUnary(ashape, HloOpcode::kLog, param)); + auto transpose = builder.AddInstruction( + HloInstruction::CreateTranspose(bshape, log, {1, 0})); + auto tanh = builder.AddInstruction( + HloInstruction::CreateUnary(bshape, HloOpcode::kTanh, transpose)); + HloModule module(TestName()); + auto computation = module.AddEntryComputation(builder.Build(tanh)); + + Shape ashape_with_layout(ashape); + Shape bshape_with_layout(bshape); + *ashape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({1, 0}); + *bshape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 1}); + + ComputationLayout computation_layout(computation->ComputeProgramShape()); + *computation_layout.mutable_parameter_layout(0) = + ShapeLayout(ashape_with_layout); + *computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout); + AssignLayouts(&module, &computation_layout); + + EXPECT_TRUE( + LayoutUtil::Equal(ashape_with_layout.layout(), log->shape().layout())); + EXPECT_TRUE(LayoutUtil::Equal(bshape_with_layout.layout(), + transpose->shape().layout())); + EXPECT_TRUE( + LayoutUtil::Equal(bshape_with_layout.layout(), tanh->shape().layout())); +} + +// Test whether LayoutAssignment assigns layouts to transpositions to make them +// bitcasts. +TEST_F(LayoutAssignmentTest, BroadcastAndTranspose) { + // param -> broadcast -> transpose + auto builder = HloComputation::Builder(TestName()); + Shape ashape = ShapeUtil::MakeShape(F32, {3, 4}); + Shape bshape = ShapeUtil::MakeShape(F32, {2, 3, 4}); + Shape cshape = ShapeUtil::MakeShape(F32, {4, 3, 2}); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, ashape, "param")); + auto broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(bshape, param, {1, 2})); + auto transpose = builder.AddInstruction( + HloInstruction::CreateTranspose(cshape, broadcast, {2, 1, 0})); + HloModule module(TestName()); + HloComputation* computation = + module.AddEntryComputation(builder.Build(transpose)); + + Shape input_shape_with_layout(ashape); + Shape output_shape_with_layout(cshape); + *input_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({1, 0}); + *output_shape_with_layout.mutable_layout() = + LayoutUtil::MakeLayout({2, 1, 0}); + + ComputationLayout computation_layout(computation->ComputeProgramShape()); + *computation_layout.mutable_parameter_layout(0) = + ShapeLayout(input_shape_with_layout); + *computation_layout.mutable_result_layout() = + ShapeLayout(output_shape_with_layout); + AssignLayouts(&module, &computation_layout); + + EXPECT_TRUE(ContainersEqual(broadcast->shape().layout().minor_to_major(), + tensorflow::gtl::ArraySlice{0, 1, 2})); +} + +TEST_F(LayoutAssignmentTest, ReshapeOperandHasMultipleUsers) { + // param[4] -> broadcast[3x4] ------> transpose[4x3]-------- -------> tuple + // \ / + // \-> tanh[3x4] -> broadcast2[2x3x4] -/ + // + // The layout of `transpose` is set to {1,0} because it provides a buffer to + // the computation result which has a fixed layout.. Therefore, `broadcast` + // (the operand of transpose) is expected to have layout {0,1} so that the + // transpose is a bitcast. 
Furthermore, `tanh` is expected to have the same + // layout as `broadcast` (i.e. {0,1}) because `tanh` is elementwise. + Shape f32_4 = ShapeUtil::MakeShape(F32, {4}); + Shape f32_34 = ShapeUtil::MakeShape(F32, {3, 4}); + Shape f32_43 = ShapeUtil::MakeShape(F32, {4, 3}); + Shape f32_234 = ShapeUtil::MakeShape(F32, {2, 3, 4}); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32_4, "param")); + auto broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(f32_34, param, {3})); + auto transpose = builder.AddInstruction( + HloInstruction::CreateTranspose(f32_43, broadcast, {1, 0})); + auto tanh = builder.AddInstruction( + HloInstruction::CreateUnary(f32_34, HloOpcode::kTanh, broadcast)); + auto broadcast2 = builder.AddInstruction( + HloInstruction::CreateBroadcast(f32_234, tanh, {2})); + auto tuple = builder.AddInstruction( + HloInstruction::CreateTuple({transpose, broadcast2})); + HloModule module(TestName()); + HloComputation* computation = + module.AddEntryComputation(builder.Build(tuple)); + + ComputationLayout computation_layout(computation->ComputeProgramShape()); + Shape param_shape_with_layout(f32_4); + Shape transpose_shape_with_layout(f32_43); + Shape broadcast2_shape_with_layout(f32_234); + *param_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0}); + *transpose_shape_with_layout.mutable_layout() = + LayoutUtil::MakeLayout({1, 0}); + *broadcast2_shape_with_layout.mutable_layout() = + LayoutUtil::MakeLayout({2, 1, 0}); + + *computation_layout.mutable_parameter_layout(0) = + ShapeLayout(param_shape_with_layout); + *computation_layout.mutable_result_layout() = + ShapeLayout(ShapeUtil::MakeTupleShape( + {transpose_shape_with_layout, broadcast2_shape_with_layout})); + AssignLayouts(&module, &computation_layout); + + EXPECT_TRUE(ContainersEqual(broadcast->shape().layout().minor_to_major(), + tensorflow::gtl::ArraySlice{0, 1})); + EXPECT_TRUE(ContainersEqual(transpose->shape().layout().minor_to_major(), + tensorflow::gtl::ArraySlice{1, 0})); + EXPECT_TRUE(ContainersEqual(tanh->shape().layout().minor_to_major(), + tensorflow::gtl::ArraySlice{0, 1})); +} + +// Add test which fails due to copy tuple. + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD new file mode 100644 index 0000000000..10468d9aae --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -0,0 +1,154 @@ +# Description: +# Libraries for helping construct LLVM IR for XLA backends. + +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = [":friends"]) + +package_group( + name = "friends", + includes = [ + "//tensorflow/compiler/xla:friends", + ], +) + +# Filegroup used to collect source files for dependency checking. 
+filegroup( + name = "c_srcs", + data = glob([ + "**/*.cc", + "**/*.h", + ]), +) + +cc_library( + name = "alias_analysis", + srcs = ["alias_analysis.cc"], + hdrs = ["alias_analysis.h"], + deps = [ + ":ir_array", + ":llvm_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/legacy_flags:alias_analysis_flags", + "//tensorflow/compiler/xla/service:buffer_assignment", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:logical_buffer", + "//tensorflow/core:lib", + "@llvm//:core", + ], +) + +cc_library( + name = "llvm_util", + srcs = ["llvm_util.cc"], + hdrs = ["llvm_util.h"], + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/legacy_flags:llvm_backend_flags", + "//tensorflow/compiler/xla/legacy_flags:llvm_util_flags", + "//tensorflow/compiler/xla/service:buffer_assignment", + "//tensorflow/core:lib", + "@llvm//:core", + "@llvm//:support", + "@llvm//:target", + ], +) + +cc_library( + name = "ir_array", + srcs = ["ir_array.cc"], + hdrs = ["ir_array.h"], + deps = [ + ":llvm_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + "@llvm//:core", + ], +) + +cc_library( + name = "llvm_loop", + srcs = ["llvm_loop.cc"], + hdrs = ["llvm_loop.h"], + deps = [ + ":ir_array", + ":llvm_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + "@llvm//:core", + ], +) + +cc_library( + name = "loop_emitter", + srcs = ["loop_emitter.cc"], + hdrs = ["loop_emitter.h"], + deps = [ + ":ir_array", + ":llvm_loop", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + "@llvm//:core", + ], +) + +cc_library( + name = "fused_ir_emitter", + srcs = ["fused_ir_emitter.cc"], + hdrs = ["fused_ir_emitter.h"], + deps = [ + ":ir_array", + ":llvm_util", + ":loop_emitter", + ":ops", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:elemental_ir_emitter", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:lib", + "@llvm//:core", + ], +) + +cc_library( + name = "ops", + srcs = ["ops.cc"], + hdrs = ["ops.h"], + deps = [ + ":ir_array", + ":llvm_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + "@llvm//:core", + ], +) + +# ----------------------------------------------------------------------------- + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/compiler/xla/service/llvm_ir/README.md b/tensorflow/compiler/xla/service/llvm_ir/README.md new file mode 100644 index 0000000000..9fe7152477 --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/README.md @@ -0,0 +1,2 @@ +Common utilites and abstractions for handling and emitting LLVM IR for XLA +backends. 
diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc new file mode 100644 index 0000000000..a552ea0218 --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc @@ -0,0 +1,195 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h" + +#include + +#include "external/llvm/include/llvm/IR/MDBuilder.h" +#include "tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/service/logical_buffer.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { +namespace llvm_ir { + +void AliasAnalysis::AddAliasingInformationToIrArray(const HloInstruction& hlo, + llvm_ir::IrArray* array) { + BufferAllocation::Index buffer_index; + if (hlo.opcode() == HloOpcode::kParameter) { + // Parameters may alias with each other but may not alias with our temporary + // buffers. + buffer_index = kParameterAliasSet; + } else { + const std::set allocations = + assignment_.GetAllocations(&hlo, /*index=*/{}); + if (allocations.empty() || allocations.size() > 1) { + // Skip HLOs which don't have buffers a buffer assigned or for which the + // buffer can't be determined statically. We cannot determine their + // aliasing properties in these cases. + return; + } + buffer_index = allocations.begin()->index(); + } + + llvm::MDNode*& alias_scope_md = alias_scope_metadata_[buffer_index]; + if (alias_scope_md == nullptr) { + alias_scope_md = + GetAliasScopeMetadataForBuffer(buffer_index, GetAliasDomain()); + } + array->AddAliasScopeMetadata(alias_scope_md); + + llvm::MDNode*& noalias_md = noalias_metadata_[buffer_index]; + if (noalias_md == nullptr) { + noalias_md = GetNoaliasMetadataForBuffer(buffer_index, GetAliasDomain(), + assignment_, hlo); + } + array->AddNoaliasMetadata(noalias_md); + + // Parameters of the entry computation are never stored to, loading from a + // parameter pointer should always return the same result within a loop. 
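+  // Tagging such loads with !invariant.load (done just below for entry
+  // parameters) lets LLVM assume the loaded value cannot change during
+  // execution, which in turn allows passes such as LICM to hoist the load out
+  // of loops.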
+ if (hlo.opcode() == HloOpcode::kParameter) { + const std::vector& parameter_instructions = + module_.entry_computation()->parameter_instructions(); + if (std::find(parameter_instructions.begin(), parameter_instructions.end(), + &hlo) != parameter_instructions.end()) { + array->AddInvariantLoad(llvm::MDNode::get(*context_, /*MDs=*/{})); + } + } +} + +llvm::MDNode* AliasAnalysis::GetAliasDomain() { + llvm::MDBuilder metadata_builder(*context_); + if (alias_domain_ == nullptr) { + alias_domain_ = metadata_builder.createAnonymousAliasScopeDomain(); + } + return alias_domain_; +} + +llvm::MDNode* AliasAnalysis::GetAliasScopeMetadataForBuffer( + BufferAllocation::Index buffer_index, llvm::MDNode* domain) { + legacy_flags::AliasAnalysisFlags* flags = + legacy_flags::GetAliasAnalysisFlags(); + if (!flags->xla_emit_alias_scope) { + return nullptr; + } + + // While we could synthesize an alias.scope, doing so is not more profitable + // than LLVM's default behavior. + if (buffer_index == kParameterAliasSet) { + return nullptr; + } + + llvm::MDBuilder metadata_builder(domain->getContext()); + llvm::MDNode* scope = metadata_builder.createAliasScope( + AsStringRef(tensorflow::strings::StrCat("buffer: ", buffer_index)), + domain); + llvm::MDNode* scope_list = llvm::MDNode::get(domain->getContext(), scope); + return scope_list; +} + +llvm::MDNode* AliasAnalysis::GetNoaliasMetadataForBuffer( + BufferAllocation::Index buffer_index, llvm::MDNode* domain, + const BufferAssignment& assignment, const HloInstruction& hlo) { + legacy_flags::AliasAnalysisFlags* flags = + legacy_flags::GetAliasAnalysisFlags(); + if (!flags->xla_emit_alias_scope) { + return nullptr; + } + + // We want to construct a list of buffers which: + // + // 1. Do not alias the given buffer. + // 2. Will plausibly be used in the vicinity of the given buffer. + // + // Making the noalias set overly large will result in either a massive + // slowdown in LLVM or LLVM will just ignore the noalias set. + // + // A plausible list of instructions are: + // 1. Users of the given hlo. + // 2. Operands of users of the given hlo. + // 3. Operands of the given hlo. + // + // This set can be increased as we need. For now only consider top-level + // buffers (index = {}) not buffers nested within the instruction's + // operands/output which are not typically touched. + std::vector worklist; + auto add_buffers_to_worklist = + [&worklist, &assignment](const HloInstruction* instruction) { + for (const LogicalBuffer* buffer : + assignment.GetSourceBuffers(instruction, /*index=*/{})) { + worklist.push_back(buffer); + } + }; + + for (HloInstruction* user : hlo.users()) { + add_buffers_to_worklist(user); + for (HloInstruction* operand : user->operands()) { + add_buffers_to_worklist(operand); + } + } + + add_buffers_to_worklist(&hlo); + for (HloInstruction* operand : hlo.operands()) { + add_buffers_to_worklist(operand); + } + + std::unordered_set buffers; + for (const LogicalBuffer* buffer : worklist) { + // Skip buffers which cannot be added to the noalias set. + if (!assignment.HasAllocation(*buffer) || + buffer->instruction()->opcode() == HloOpcode::kParameter) { + continue; + } + BufferAllocation::Index noalias_index = + assignment.GetAssignedAllocation(*buffer).index(); + // Our buffer must not noalias itself. + if (noalias_index != buffer_index) { + buffers.insert(noalias_index); + // Some instructions have too many operands, causing the noalias set to be + // too large. 
To reduce compilation time (b/31901575), truncate noalias + // sets to at most 500 elements. + // + // Future work: improvements to LLVM's scoped AA that avoid creating a + // MDNode set for every alias query can help to reduce the compilation + // time as well. + constexpr int kMaxNoAliasSetSize = 500; + if (buffers.size() >= kMaxNoAliasSetSize) { + break; + } + } + } + + // Don't bother constructing a noalias metadata node if it would be empty. + if (buffers.empty()) { + return nullptr; + } + + llvm::MDBuilder metadata_builder(domain->getContext()); + std::vector scopes; + for (BufferAllocation::Index noalias_index : buffers) { + llvm::MDNode* scope = metadata_builder.createAliasScope( + AsStringRef(tensorflow::strings::StrCat("buffer: ", noalias_index)), + domain); + scopes.push_back(scope); + } + llvm::MDNode* noalias_list = + llvm::MDNode::get(domain->getContext(), AsArrayRef(scopes)); + return noalias_list; +} + +} // namespace llvm_ir +} // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h new file mode 100644 index 0000000000..d8d45dd49b --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h @@ -0,0 +1,93 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_ALIAS_ANALYSIS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_ALIAS_ANALYSIS_H_ + +#include + +#include "external/llvm/include/llvm/IR/Module.h" +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace xla { +namespace llvm_ir { + +// Helper functionality used to augment the LLVM IR emitted with alias-scope +// metadata. +class AliasAnalysis { + public: + AliasAnalysis(const HloModule& module, const BufferAssignment& assignment, + llvm::LLVMContext* context) + : module_(module), assignment_(assignment), context_(context) {} + + // Augments IrArray with aliasing information. + void AddAliasingInformationToIrArray(const HloInstruction& hlo, + llvm_ir::IrArray* array); + + private: + // Returns a unique alias domain for this emitter. + llvm::MDNode* GetAliasDomain(); + + // Returns an alias.scope metadata node corresponding to a given buffer index. + llvm::MDNode* GetAliasScopeMetadataForBuffer( + BufferAllocation::Index buffer_index, llvm::MDNode* domain); + + // Returns a noalias metadata node corresponding to a given buffer index. + // + // |buffer_index| is the buffer index. + // + // |domain| corresponds to the alias scope domain as documented at + // http://llvm.org/docs/LangRef.html#noalias-and-alias-scope-metadata + // + // |hlo| is the instruction we are computing a noalias set for. 
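+  //
+  // Returns nullptr when alias-scope emission is disabled by flags or when
+  // the computed noalias set would be empty.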
+ llvm::MDNode* GetNoaliasMetadataForBuffer( + BufferAllocation::Index buffer_index, llvm::MDNode* domain, + const BufferAssignment& assignment, const HloInstruction& hlo); + + // The HLO module we are compiling for. + const HloModule& module_; + + // Assignment of the temporary buffers needed by the computation and their + // shape information. + const BufferAssignment& assignment_; + + // The LLVM context which we are using for IR emission. + llvm::LLVMContext* context_; + + // Holds the alias domain for this computation. + llvm::MDNode* alias_domain_ = nullptr; + + // Index in alias_scope_metadata_ and noalias_metadata_ for parameters + // of the entry computation which have special aliasing properties. + static constexpr int kParameterAliasSet = -1; + + // A map from a buffer index to metadata corresponding to its alias.scope + // metadata. The index kParameterAliasSet is used to hold aliasing + // information for parameters. + std::unordered_map alias_scope_metadata_; + + // A map from a buffer index to metadata corresponding to its noalias + // metadata. + std::unordered_map noalias_metadata_; +}; + +} // namespace llvm_ir +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_ALIAS_ANALYSIS_H_ diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc new file mode 100644 index 0000000000..b259d34870 --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc @@ -0,0 +1,147 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" + +#include + +#include "external/llvm/include/llvm/IR/BasicBlock.h" +#include "external/llvm/include/llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ops.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +using llvm_ir::IrArray; + +Status FusedIrEmitter::DefaultAction(HloInstruction* hlo) { + generators_[hlo] = + [=](const IrArray::Index& index) -> StatusOr { + if (generated_value_cache_[hlo].count(index.multidim()) > 0) { + llvm::Value* generated_value = + generated_value_cache_[hlo][index.multidim()]; + llvm::BasicBlock* generated_value_bb = nullptr; + if (auto* generated_instruction = + llvm::dyn_cast(generated_value)) { + generated_value_bb = generated_instruction->getParent(); + } + // Ideally, we should be able to reuse the cached generated value if it + // dominates the current insertion block. 
However, the check for dominance + // can be expensive and unreliable when the function is being constructed. + // + // It's also worth experimenting what if we don't do caching at all. + // LLVM's CSE or GVN should be able to easily merge common subexpressions + // that would be regenerated without caching. But this might increase the + // JIT compilation time. + if (generated_value_bb == nullptr || + generated_value_bb == ir_builder_->GetInsertBlock()) { + VLOG(3) << "The cached generated value is reused."; + return generated_value; + } + VLOG(3) << "The cached generated value can't be reuse, because it is at " + "a different BB (" + << llvm_ir::AsString(generated_value_bb->getName()) + << ") from the current insertion block (" + << llvm_ir::AsString(ir_builder_->GetInsertBlock()->getName()) + << ")."; + } + + TF_ASSIGN_OR_RETURN( + generated_value_cache_[hlo][index.multidim()], + elemental_emitter_->MakeElementGenerator(hlo, generators_)(index)); + return generated_value_cache_[hlo][index.multidim()]; + }; + return Status::OK(); +} + +Status FusedIrEmitter::HandleConstant(HloInstruction* constant, + const Literal& literal) { + llvm::Constant* initializer = + llvm_ir::ConvertLiteralToIrConstant(literal, ir_builder_); + llvm::GlobalVariable* global = new llvm::GlobalVariable( + *ir_builder_->GetInsertBlock()->getModule(), initializer->getType(), + /*isConstant=*/true, llvm::GlobalValue::ExternalLinkage, initializer, + /*Name=*/""); + generators_[constant] = [=](const IrArray::Index& index) { + return IrArray(global, constant->shape()) + .EmitReadArrayElement(index, ir_builder_); + }; + + return Status::OK(); +} + +Status FusedIrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element, + HloInstruction* operand) { + // Lookup ir value for 'operand'. + auto it = gte_values_.find(operand); + if (it == gte_values_.end()) { + return Unimplemented( + "GetTupleElement fusion currently only supports" + " parameter operands, but found operand: %s", + operand->name().c_str()); + } + // Emit code to lookup tuple element pointer, and store it in 'gte_values_'. + llvm::Value* tuple_element_ptr = llvm_ir::EmitGetTupleElement( + get_tuple_element->shape(), get_tuple_element->tuple_index(), + /*alignment=*/1, it->second, ir_builder_); + gte_values_.insert(std::make_pair(get_tuple_element, tuple_element_ptr)); + // Emit code to read base tuple element array (if non-tuple shaped). + if (!ShapeUtil::IsTuple(get_tuple_element->shape())) { + generators_[get_tuple_element] = + [=](const IrArray::Index& index) -> StatusOr { + // TODO(b/34080002) Add aliasing information to tuple element IrArray. + return IrArray(tuple_element_ptr, get_tuple_element->shape()) + .EmitReadArrayElement(index, ir_builder_); + }; + } + return Status::OK(); +} + +Status FusedIrEmitter::HandleParameter(HloInstruction* parameter) { + generators_[parameter] = [=](const IrArray::Index& index) { + return parameter_arrays_[parameter->parameter_number()] + .EmitReadArrayElement(index, ir_builder_); + }; + // Store ir value for fusion operand associated with fusion parameter to be + // accessed by subsequent fused GetTupleElement instructions. 
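+  // (What is stored is the parameter array's base pointer;
+  // HandleGetTupleElement later looks it up to compute tuple element
+  // addresses.)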
+ gte_values_.insert(std::make_pair( + parameter, + parameter_arrays_[parameter->parameter_number()].GetBasePointer())); + return Status::OK(); +} + +Status FusedIrEmitter::FinishVisit(HloInstruction* root) { + fused_root_ = root; + return tensorflow::Status::OK(); +} + +FusedIrEmitter::Generator FusedIrEmitter::GetRootGenerator() const { + CHECK_NE(nullptr, fused_root_) + << "GetRootGenerator should be called after Accept."; + return generators_.at(fused_root_); +} + +FusedIrEmitter::Generator FusedIrEmitter::GetGenerator( + const HloInstruction* instruction) const { + return generators_.at(instruction); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h new file mode 100644 index 0000000000..303bb3ee6b --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h @@ -0,0 +1,94 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_FUSED_IR_EMITTER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_FUSED_IR_EMITTER_H_ + +#include +#include + +#include "external/llvm/include/llvm/IR/IRBuilder.h" +#include "external/llvm/include/llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" +#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace xla { + +// Unlike IrEmitter, this creates host functions which emit IR to generate the +// output element at the given index. It is used to generate fused operations. +class FusedIrEmitter : public DfsHloVisitorWithDefault { + public: + using Generator = llvm_ir::ElementGenerator; + + FusedIrEmitter(tensorflow::gtl::ArraySlice parameter_arrays, + ElementalIrEmitter* elemental_emitter) + : parameter_arrays_(parameter_arrays), + elemental_emitter_(elemental_emitter), + ir_builder_(elemental_emitter->ir_builder()) {} + + Status DefaultAction(HloInstruction* hlo) override; + + Status HandleConstant(HloInstruction* constant, + const Literal& literal) override; + + Status HandleGetTupleElement(HloInstruction* get_tuple_element, + HloInstruction* operand) override; + + Status HandleParameter(HloInstruction* parameter) override; + + Status FinishVisit(HloInstruction* root) override; + + // Returns the generator function for the root of the fused computation. + Generator GetRootGenerator() const; + + // Returns the generator function for the given instruction. 
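+  // When invoked, the generator emits IR at the builder's current insert
+  // point that computes the instruction's output element at the given index.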
+ Generator GetGenerator(const HloInstruction* instruction) const; + + private: + // Arrays of parameters of fusion instruction + tensorflow::gtl::ArraySlice parameter_arrays_; + + ElementalIrEmitter* elemental_emitter_; + + // This member will be set by FinishVisit and used in GetRootGenerator. + const HloInstruction* fused_root_ = nullptr; + + // Borrowed + llvm::IRBuilder<>* ir_builder_; + + // Map from instruction pointers to functions to generate elements of their + // outputs + std::unordered_map generators_; + + // Cache of generated values, lest we regenerate an element of a node with + // multiple outgoing edges + std::unordered_map, llvm::Value*>> + generated_value_cache_; + + // Stores ir values required to emit fused (and possibly nested) + // GetTupleElement instructions. + std::unordered_map gte_values_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_FUSED_IR_EMITTER_H_ diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc new file mode 100644 index 0000000000..c095dea7a8 --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc @@ -0,0 +1,274 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" + +#include "external/llvm/include/llvm/IR/Constants.h" +#include "external/llvm/include/llvm/IR/Instructions.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace llvm_ir { + +IrArray::Index::Index(llvm::Value* linear, const Shape& shape, + llvm::IRBuilder<>* ir_builder) + : multidim_(ShapeUtil::Rank(shape)), + linear_(linear), + layout_(shape.layout()), + dims_(shape.dimensions().begin(), shape.dimensions().end()) { + CHECK(LayoutUtil::HasLayout(shape)) + << "Shape " << ShapeUtil::HumanStringWithLayout(shape) + << " should have a layout."; + int64 divisor = 1; + for (int64 dimension : layout_.minor_to_major()) { + int64 size_of_current_dimension = shape.dimensions(dimension); + // Emit IR instructions that compute + // (linear_index / divisor) % current_dimension + multidim_[dimension] = ir_builder->CreateURem( + ir_builder->CreateUDiv(linear, ir_builder->getInt64(divisor)), + ir_builder->getInt64(size_of_current_dimension)); + divisor *= size_of_current_dimension; + } +} + +IrArray::Index::Index(tensorflow::gtl::ArraySlice multidim, + llvm::Value* linear, const Shape& shape) + : multidim_(multidim.begin(), multidim.end()), + linear_(linear), + layout_(shape.layout()), + dims_(shape.dimensions().begin(), shape.dimensions().end()) { + CHECK_EQ(shape.dimensions_size(), multidim.size()); + CHECK(LayoutUtil::HasLayout(shape)) + << "Shape " << ShapeUtil::HumanStringWithLayout(shape) + << " should have a layout."; +} + +IrArray::Index::Index(tensorflow::gtl::ArraySlice multidim, + const Shape& shape, llvm::IRBuilder<>* ir_builder) + : multidim_(multidim.begin(), multidim.end()), + layout_(shape.layout()), + dims_(shape.dimensions().begin(), shape.dimensions().end()) { + CHECK_EQ(shape.dimensions_size(), multidim.size()); + CHECK(LayoutUtil::HasLayout(shape)); + linear_ = Linearize(shape, ir_builder); +} + +IrArray::IrArray(llvm::Value* base_ptr, const Shape& shape) + : base_ptr_(base_ptr), shape_(&shape) { + TF_CHECK_OK(ShapeUtil::ValidateShape(shape)); + CHECK(base_ptr_->getType()->isPointerTy()); + int depth = 0; + element_type_ = + llvm::cast(base_ptr_->getType())->getElementType(); + while (llvm::ArrayType* array_type = + llvm::dyn_cast(element_type_)) { + element_type_ = array_type->getElementType(); + ++depth; + } + + if (ShapeUtil::Rank(*shape_) == 0) { + DCHECK(depth == 1 || depth == 0) << depth; + } else { + DCHECK_EQ(depth, ShapeUtil::Rank(*shape_)) << shape.ShortDebugString(); + } +} + +// Returns whether given linear index valid on given shape. 
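+// That is, returns true only if the cached linear index was computed for a
+// shape whose non-degenerate dimensions and layout match those of the given
+// shape, so the linear index can be reused directly.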
+bool IrArray::Index::LinearValidOnShape(const Shape& a) const { + auto b = ShapeUtil::MakeShape(PRED /* irrelevant */, dims_); + *b.mutable_layout() = layout_; + return linear_ != nullptr && + ContainersEqual( + ShapeUtil::StripDegenerateDimensions(a).dimensions(), + ShapeUtil::StripDegenerateDimensions(b).dimensions()) && + LayoutUtil::Equal(ShapeUtil::StripDegenerateDimensions(a).layout(), + ShapeUtil::StripDegenerateDimensions(b).layout()); +} + +IrArray::Index IrArray::Index::SourceIndexOfReshape( + const Shape& output_shape, const Shape& input_shape, + llvm::IRBuilder<>* builder) const { + const auto& target_index = *this; + CHECK_EQ(target_index.size(), ShapeUtil::Rank(output_shape)); + llvm::Value* logical_linear_index = Linearize(output_shape, builder); + // Delinearizes logical_linear_index for the source array in row-major + // collapsed order. The first rank-1 indices are the remainder of the + // linear index by each dimension size. + std::vector> unmodified_dims = + ShapeUtil::DimensionsUnmodifiedByReshape(input_shape, output_shape); + std::vector source_multidim_index(ShapeUtil::Rank(input_shape)); + for (int64 i = ShapeUtil::Rank(input_shape) - 1; i >= 0; --i) { + auto divisor = builder->getInt64(input_shape.dimensions(i)); + if (input_shape.dimensions(i) <= 1) { + source_multidim_index[i] = builder->getInt64(0); + } else { + // Search unmodified_dims for a pair whose first element is exactly "i". + // + // Because unmodified_dims are sorted by both "first" and "second", and + // "i" is monotonically decreasing, we avoid redundant searching by + // popping the back of unmodified_dims until the rear pair's first element + // <= i. If we stop precisely at "i", we find a match. + while (!unmodified_dims.empty() && unmodified_dims.back().first > i) { + unmodified_dims.pop_back(); + } + if (!unmodified_dims.empty() && unmodified_dims.back().first == i) { + source_multidim_index[i] = target_index[unmodified_dims.back().second]; + } else { + source_multidim_index[i] = + builder->CreateURem(logical_linear_index, divisor); + } + } + logical_linear_index = builder->CreateUDiv(logical_linear_index, divisor); + } + + if (linear() != nullptr && + ShapeUtil::ReshapeIsBitcast(input_shape, output_shape)) { + return Index(source_multidim_index, linear(), input_shape); + } + return Index(source_multidim_index); +} + +IrArray::Index IrArray::Index::SourceIndexOfTranspose( + const Shape& shape, const Shape& operand_shape, + tensorflow::gtl::ArraySlice dimension_mapping, + llvm::IRBuilder<>* builder) const { + std::vector operand_multidim_index = + Permute(dimension_mapping, multidim()); + if (linear() != nullptr && + ShapeUtil::TransposeIsBitcast(operand_shape, shape, dimension_mapping)) { + return Index(operand_multidim_index, linear(), operand_shape); + } + return Index(operand_multidim_index); +} + +llvm::Value* IrArray::Index::Linearize(const Shape& shape, + llvm::IRBuilder<>* builder) const { + // Each dimension is multiplied by the product of the sizes of all + // earlier dimensions and added to the accumulator logical_linear_index. 
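+  // In other words, for logical dimensions 0..n-1 this computes
+  //   multidim[n-1] + dims[n-1] * (multidim[n-2] + dims[n-2] * (...))
+  // which is a row-major linearization over the logical dimensions,
+  // independent of the physical layout.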
+ llvm::Value* logical_linear_index = builder->getInt64(0); + int64 multiplier = 1; + for (ssize_t i = size() - 1; i >= 0; --i) { + llvm::Value* addend = + builder->CreateMul((*this)[i], builder->getInt64(multiplier), "", + /*HasNUW=*/true, /*HasNSW=*/true); + logical_linear_index = builder->CreateAdd(logical_linear_index, addend, "", + /*HasNUW=*/true, /*HasNSW=*/true); + multiplier *= shape.dimensions(i); + } + return logical_linear_index; +} + +llvm::Value* IrArray::EmitArrayElementAddress( + const IrArray::Index& index, llvm::IRBuilder<>* ir_builder, + tensorflow::StringPiece name) const { + if (ShapeUtil::IsScalar(*shape_)) { + // Special handling of scalars: a scalar pretends to have the same value for + // every index, thus effectively implementing broadcasting of its value + // over higher-rank arrays. + return base_ptr_; + } + CHECK_EQ(index.size(), ShapeUtil::Rank(*shape_)); + + std::vector actual_index; + bool is_implicit_broadcast = false; + // We perform broadcasting when the operand shape has dimension(s) of size + // 1. In this case we fix the index value for that dimension to zero. This + // effectively broadcasts along this dimension. + for (int64 i = 0; i < index.size(); ++i) { + auto dim = shape_->dimensions(i); + actual_index.push_back(dim == 1 ? ir_builder->getInt64(0) : index[i]); + is_implicit_broadcast |= dim == 1; + } + + if (!is_implicit_broadcast && index.LinearValidOnShape(*shape_)) { + return ir_builder->CreateInBoundsGEP( + ir_builder->CreateBitCast( + base_ptr_, PrimitiveTypeToIrType(shape_->element_type(), ir_builder) + ->getPointerTo()), + {index.linear()}, llvm_ir::AsStringRef(name)); + } + + // "base_ptr_" has the type of "*" + // (e.g. [3 x [2 x float]]*). Therefore, the address of the indexed element + // should be computed by + // + // getelementptr base_ptr_, 0, most major index, ..., most minor index + std::vector gep_indices(1, ir_builder->getInt64(0)); + for (int64 i = shape_->layout().minor_to_major_size() - 1; i >= 0; --i) { + int64 dimension = shape_->layout().minor_to_major(i); + gep_indices.push_back(actual_index[dimension]); + } + return ir_builder->CreateInBoundsGEP(base_ptr_, gep_indices, + llvm_ir::AsStringRef(name)); +} + +llvm::Value* IrArray::EmitReadArrayElement(const Index& index, + llvm::IRBuilder<>* ir_builder, + tensorflow::StringPiece name) const { + llvm::Value* element_address = + EmitArrayElementAddress(index, ir_builder, name); + llvm::LoadInst* load = ir_builder->CreateLoad(element_address); + llvm_ir::SetTbaaForInstruction(load, GetShape(), + /*is_pointer_to=*/false); + for (const std::pair& kind_md_pair : metadata_) { + int kind = kind_md_pair.first; + llvm::MDNode* md = kind_md_pair.second; + load->setMetadata(kind, md); + } + return load; +} + +void IrArray::EmitWriteArrayElement(const Index& index, llvm::Value* value, + llvm::IRBuilder<>* ir_builder) const { + llvm::Value* element_address = EmitArrayElementAddress(index, ir_builder); + llvm::StoreInst* store = ir_builder->CreateStore(value, element_address); + llvm_ir::SetTbaaForInstruction(store, GetShape(), + /*is_pointer_to=*/false); + for (const std::pair& kind_md_pair : metadata_) { + int kind = kind_md_pair.first; + CHECK_NE(kind, llvm::LLVMContext::MD_invariant_load); + llvm::MDNode* md = kind_md_pair.second; + store->setMetadata(kind, md); + } +} + +IrArray IrArray::CastToShape(const Shape& new_shape, + llvm::IRBuilder<>* ir_builder) const { + llvm::Type* new_ir_type = llvm_ir::ShapeToIrType(new_shape, ir_builder); + return IrArray( + 
ir_builder->CreatePointerCast(base_ptr_, new_ir_type->getPointerTo()), + new_shape); +} + +/* static */ IrArray::Index IrArray::BumpIndex(const Index& index, + int64 which_dimension, + int64 addend, + llvm::IRBuilder<>* ir_builder) { + Index new_index = index; + new_index[which_dimension] = ir_builder->CreateAdd( + index[which_dimension], ir_builder->getInt64(addend), "", /*HasNUW=*/true, + /*HasNSW=*/true); + return new_index; +} + +} // namespace llvm_ir +} // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h new file mode 100644 index 0000000000..0b182267c3 --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h @@ -0,0 +1,248 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_ARRAY_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_ARRAY_H_ + +#include +#include + +#include "external/llvm/include/llvm/IR/IRBuilder.h" +#include "external/llvm/include/llvm/IR/Value.h" +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace llvm_ir { + +// IrArray represents an XLA array at the LLVM IR level. This class +// encapsulates a base pointer to the buffer holding the array (as an LLVM +// Value) and the shape of the array. The class includes methods for emitting +// LLVM IR sequences which access elements of the array at a multidimensional +// index (eg, [x, y, z] in a 3-dimensional array). Arbitrary shape and layouts +// are supported. +class IrArray { + public: + // A multidimensional index into an IrArray. The index for dimension zero is + // first in the vector. This is the reverse order of the notation used for + // describing the dimensions of an array. That is, for a [4 x 3 x 2] array + // dimension zero has size 2, dimension one has size 3, and dimension two has + // size 4. Thus the index {1, 2, 3} indexes the last element of this [4 x 3 x + // 2] array. + // + // This may also keep a linear index and the layout and dimensions it was + // emitted for; if the shape where this `Index` is used matches, the linear + // index may be used, potentially sparing the cost of computing the + // multidimensional index, which LLVM DCE can delete. + class Index { + public: + // Constructs an empty zero-dimensional index. + Index() {} + + // Constructs an index of rank "size". Each dimension of the index is + // initialized to "value". + explicit Index(size_t size, llvm::Value* value = nullptr) + : multidim_(size, value) {} + + // Constructs an index from multi-dimensional index "multidim". 
The linear + // index is set to nullptr. + explicit Index(tensorflow::gtl::ArraySlice multidim) + : multidim_(multidim.begin(), multidim.end()) {} + + // Constructs an index from linear index "linear" and computes the + // multi-dimensional index from "linear" and "shape". "ir_builder" is the IR + // builder to emit the index of each dimension in the multi-dimensional + // index. + // + // Precondition: "shape" has a layout. + Index(llvm::Value* linear, const Shape& shape, + llvm::IRBuilder<>* ir_builder); + + // Constructs an index from the given multi-dimensional index and the shape + // that it indexes into. Also, computes the linear index according to + // "shape". + // + // Precondition: "shape" has a layout. + Index(tensorflow::gtl::ArraySlice multidim, + const Shape& shape, llvm::IRBuilder<>* ir_builder); + + // Consturcts an index from both a multi-dimensional index and a linear + // index. "shape" has the same meaning as that in the constructor that takes + // only a linear index. + Index(tensorflow::gtl::ArraySlice multidim, + llvm::Value* linear, const Shape& shape); + + const std::vector& multidim() const { return multidim_; } + llvm::Value* linear() const { return linear_; } + + size_t size() const { return multidim().size(); } + + llvm::Value* operator[](size_t i) const { return multidim()[i]; } + llvm::Value*& operator[](size_t i) { return multidim()[i]; } + + void push_back(llvm::Value* value) { multidim().push_back(value); } + + using iterator = std::vector::iterator; + using const_iterator = std::vector::const_iterator; + + iterator begin() { return multidim().begin(); } + iterator end() { return multidim().end(); } + + const_iterator begin() const { return multidim().begin(); } + const_iterator end() const { return multidim().end(); } + + bool LinearValidOnShape(const Shape& a) const; + + // Given that "this" is the target index of a reshape from `operand_shape` + // to `shape`, returns the source index. + Index SourceIndexOfReshape(const Shape& shape, const Shape& operand_shape, + llvm::IRBuilder<>* builder) const; + + // Given that "this" is the target index of a transpose from `operand_shape` + // to `shape` with the given dimension mapping, returns the source index. + Index SourceIndexOfTranspose( + const Shape& shape, const Shape& operand_shape, + tensorflow::gtl::ArraySlice dimension_mapping, + llvm::IRBuilder<>* builder) const; + + // Linearizes the index into the given shape, i.e. reshapes it to rank-1 and + // returns the index into the sole dimension 0 of the new shape. + llvm::Value* Linearize(const Shape& shape, + llvm::IRBuilder<>* builder) const; + + private: + // Changing the multi-dimensional index invalidates the linear index. + std::vector& multidim() { + linear_ = nullptr; + return multidim_; + } + + std::vector multidim_; + + // These values are purely for efficiency; `multidim_` is enough to find the + // element at a given `Index`, but if a loop is emitted with a linear index + // space, that linear index can be saved in `linear_`, and the layout and + // dimensions of the shape the loop was emitted for in `layout_` and + // `dims_`, and if the `Index` is used in another array, and its layout and + // dimensions match, the linear index can be used, sparing the cost of + // computing `multidim_`, which LLVM DCE could potentially so delete. + // Modifying `multidim_` after construction nullifies `linear_`, lest it + // be used wrongly, as it would be valid no more. 
+ // If a loop is emitted with a multidimensional index space, `linear_` would + // be null and `layout_` and `dims_` would be ignored. + llvm::Value* linear_ = nullptr; + Layout layout_; + std::vector dims_; + }; + + // Default constructor. Constructs an IrArray in a null status. + IrArray() : base_ptr_(nullptr), shape_(nullptr) {} + + // Construct an IrArray with the given base pointer and shape. base_ptr is a + // pointer type pointing to the first element(lowest address) of the array. + IrArray(llvm::Value* base_ptr, const Shape& shape); + + // Default implementations of copying and moving. + IrArray(IrArray&& other) = default; + IrArray(const IrArray& other) = default; + IrArray& operator=(IrArray&& other) = default; + IrArray& operator=(const IrArray& other) = default; + + llvm::Value* GetBasePointer() const { return base_ptr_; } + llvm::Type* GetElementLlvmType() const { return element_type_; } + + const Shape& GetShape() const { + CHECK(shape_ != nullptr); + return *shape_; + } + + // Emit a sequence of instructions to compute the address of the element in + // the given array at the given index. Returns the address of the element as + // an LLVM Value. + // + // The optional name is useful for debugging when looking at + // the emitted LLVM IR. + llvm::Value* EmitArrayElementAddress(const Index& index, + llvm::IRBuilder<>* ir_builder, + tensorflow::StringPiece name = "") const; + + // Emit IR to read an array element at the given index. Returns the read + // result (effectively, a Value loaded from memory). This method seamlessly + // handles scalar shapes by broadcasting their value to all indices (index is + // ignored). + // + // The optional name is useful for debugging when looking at + // the emitted LLVM IR. + llvm::Value* EmitReadArrayElement(const Index& index, + llvm::IRBuilder<>* ir_builder, + tensorflow::StringPiece name = "") const; + + // Emit IR to write the given value to the array element at the given index. + void EmitWriteArrayElement(const Index& index, llvm::Value* value, + llvm::IRBuilder<>* ir_builder) const; + + // Returns a new IrArray whose shape is "new_shape" and base pointer is a + // bitcast of the base pointer of "this" IrArray. + IrArray CastToShape(const Shape& new_shape, + llvm::IRBuilder<>* ir_builder) const; + + void AddAliasScopeMetadata(llvm::MDNode* alias_scope) { + AddMetadata(llvm::LLVMContext::MD_alias_scope, alias_scope); + } + + void AddNoaliasMetadata(llvm::MDNode* noalias) { + AddMetadata(llvm::LLVMContext::MD_noalias, noalias); + } + + void AddInvariantLoad(llvm::MDNode* invariant_load) { + AddMetadata(llvm::LLVMContext::MD_invariant_load, invariant_load); + } + + // Bumps the "which_dimension" value within the provided index by the provided + // addend. + static Index BumpIndex(const Index& index, int64 which_dimension, + int64 addend, llvm::IRBuilder<>* ir_builder); + + private: + // Add the specified LLVM IR metadata to loads/stores associated with this + // IrArray. + void AddMetadata(int kind, llvm::MDNode* md) { + InsertOrDie(&metadata_, kind, md); + } + + // Address of the base of the array as an LLVM Value. + llvm::Value* base_ptr_; + + // The LLVM type of the elements in the array. + llvm::Type* element_type_; + + // Shape of the XLA array. + const Shape* shape_; + + // The list of key/value pairs used when attaching metadata to emitted + // loads/stores for this array. They keys are the metadata kinds and the + // values are the metadata nodes. 
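+  // The kinds are llvm::LLVMContext metadata kind IDs (MD_alias_scope,
+  // MD_noalias and MD_invariant_load above).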
+ std::map metadata_; +}; + +} // namespace llvm_ir +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_ARRAY_H_ diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc new file mode 100644 index 0000000000..4ccded61e7 --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc @@ -0,0 +1,197 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" + +#include +#include + +#include "external/llvm/include/llvm/IR/Constants.h" +#include "external/llvm/include/llvm/IR/Function.h" +#include "external/llvm/include/llvm/IR/Instructions.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { +namespace llvm_ir { + +ForLoop::ForLoop(tensorflow::StringPiece suffix, llvm::Value* start_index, + llvm::Value* end_index, llvm::Value* step) + : suffix_(suffix.ToString()), + start_index_(start_index), + end_index_(end_index), + step_(step), + insert_before_bb_(nullptr) {} + +/* static */ std::unique_ptr ForLoop::EmitForLoop( + tensorflow::StringPiece suffix, llvm::Value* start_index, + llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* ir_builder) { + std::unique_ptr loop( + new ForLoop(suffix, start_index, end_index, step)); + loop->Emit(ir_builder); + return loop; +} + +void ForLoop::Emit(llvm::IRBuilder<>* ir_builder) { + // The preheader block is the block the builder is currently emitting + // code into. + preheader_bb_ = ir_builder->GetInsertBlock(); + + llvm::BasicBlock::iterator insert_point = ir_builder->GetInsertPoint(); + if (insert_point == preheader_bb_->end()) { + // We're emitting the loop at the end of a basic block. Verify there is no + // terminator (eg, branch) in the basic block. + CHECK_EQ(nullptr, preheader_bb_->getTerminator()); + + exit_bb_ = CreateBasicBlockWithSuffix("loop_exit", ir_builder); + } else { + // We're emitting the loop into the middle of a basic block. splitBasicBlock + // requires that this basic block be well-formed (have a terminator). + CHECK_NE(nullptr, preheader_bb_->getTerminator()); + + // Split the preheader to create an exit basic block. The exit basic block + // will contain all instructions at or after insert_point. + exit_bb_ = preheader_bb_->splitBasicBlock( + insert_point, GetNameWithSuffix("loop_exit").c_str()); + + // splitBasicBlock adds an unconditional branch between the split basic + // blocks. Remove it. An unconditional branch will be added below from the + // preheader to the header. 
+ preheader_bb_->getTerminator()->eraseFromParent(); + } + insert_before_bb_ = exit_bb_; + + // Create remaining basic block which form the inside of the loop. + header_bb_ = CreateBasicBlockWithSuffix("loop_header", ir_builder); + body_bb_ = CreateBasicBlockWithSuffix("loop_body", ir_builder); + + // Function entry basic block. + // Emit alloca for the induction variable. We do this at the entry to the + // basic block to ensure the alloc only executes once per function (we could + // be emitting a nested loop). + llvm::Function* func = preheader_bb_->getParent(); + ir_builder->SetInsertPoint(&func->getEntryBlock(), + func->getEntryBlock().getFirstInsertionPt()); + llvm::Value* indvar_address = + ir_builder->CreateAlloca(ir_builder->getInt64Ty(), nullptr, + GetNameWithSuffix("invar_address").c_str()); + + // Preheader basic block. + // Initialize induction variable starting index. Create branch to the header. + ir_builder->SetInsertPoint(preheader_bb_); + ir_builder->CreateStore(start_index_, indvar_address); + // The preheader should not have a branch yet. + CHECK_EQ(preheader_bb_->getTerminator(), nullptr); + ir_builder->CreateBr(header_bb_); + + // Header basic block. + // Emit the loop conditional branch. Load and compare indvar with ending + // index and jump to loop exit if equal. Jump to body otherwise. + ir_builder->SetInsertPoint(header_bb_); + indvar_ = ir_builder->CreateLoad(indvar_address, + GetNameWithSuffix("indvar").c_str()); + llvm::Value* exit_cond = ir_builder->CreateICmpUGE(indvar_, end_index_); + ir_builder->CreateCondBr(/*Cond=*/exit_cond, + /*True=*/exit_bb_, /*False=*/body_bb_); + + // Body basic block. + // Increment indvar, store indvar, and jump to header. + ir_builder->SetInsertPoint(body_bb_); + llvm::Value* step = step_; + llvm::Value* indvar = indvar_; + + llvm::Value* indvar_inc = + ir_builder->CreateAdd(indvar, step, "invar.inc", + /*HasNUW=*/true, /*HasNSW=*/true); + ir_builder->CreateStore(indvar_inc, indvar_address); + ir_builder->CreateBr(header_bb_); + + // Re-point the IR builder to the loop exit block. + ir_builder->SetInsertPoint(exit_bb_); +} + +string ForLoop::GetNameWithSuffix(tensorflow::StringPiece name) { + if (suffix_.empty()) { + return name.ToString(); + } else { + return tensorflow::strings::StrCat(name, ".", suffix_); + } +} + +llvm::BasicBlock* ForLoop::CreateBasicBlockWithSuffix( + tensorflow::StringPiece name, llvm::IRBuilder<>* ir_builder) { + return CreateBasicBlock(insert_before_bb_, GetNameWithSuffix(name), + ir_builder); +} + +std::unique_ptr ForLoopNest::AddLoop(tensorflow::StringPiece suffix, + llvm::Value* start_index, + llvm::Value* end_index) { + if (inner_loop_body_bb_ != nullptr) { + // Create this loop inside the previous one. 
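+    // The nested loop is emitted at the top of the innermost body block,
+    // ahead of the induction-variable increment and back-branch that
+    // ForLoop::Emit placed there.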
+ ir_builder_->SetInsertPoint(&*inner_loop_body_bb_->getFirstInsertionPt()); + } + std::unique_ptr loop = ForLoop::EmitForLoop( + suffix, start_index, end_index, ir_builder_->getInt64(1), ir_builder_); + + if (outer_loop_preheader_bb_ == nullptr) { + outer_loop_preheader_bb_ = loop->GetPreheaderBasicBlock(); + } + + if (outer_loop_exit_bb_ == nullptr) { + outer_loop_exit_bb_ = loop->GetExitBasicBlock(); + } + + inner_loop_body_bb_ = loop->GetBodyBasicBlock(); + + return loop; +} + +std::unique_ptr ForLoopNest::AddLoop(int64 start_index, + int64 end_index, + tensorflow::StringPiece suffix) { + CHECK_LE(start_index, end_index); + return AddLoop(suffix, ir_builder_->getInt64(start_index), + ir_builder_->getInt64(end_index)); +} + +IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape, + tensorflow::StringPiece suffix) { + std::vector dimensions(ShapeUtil::Rank(shape)); + std::iota(dimensions.begin(), dimensions.end(), 0); + return AddLoopsForShapeOnDimensions(shape, dimensions, suffix); +} + +IrArray::Index ForLoopNest::AddLoopsForShapeOnDimensions( + const Shape& shape, tensorflow::gtl::ArraySlice dimensions, + tensorflow::StringPiece suffix) { + llvm_ir::IrArray::Index index(shape.dimensions_size(), nullptr); + for (int64 dimension : dimensions) { + std::unique_ptr loop = AddLoop( + /*start_index=*/0, + /*end_index=*/shape.dimensions(dimension), + /*suffix=*/tensorflow::strings::Printf( + "%s%lld", suffix.ToString().c_str(), dimension)); + index[dimension] = loop->GetIndVarValue(); + } + return index; +} + +} // namespace llvm_ir +} // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h new file mode 100644 index 0000000000..0cc82b040d --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h @@ -0,0 +1,230 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_LLVM_LOOP_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_LLVM_LOOP_H_ + +#include +#include + +#include "external/llvm/include/llvm/IR/BasicBlock.h" +#include "external/llvm/include/llvm/IR/IRBuilder.h" +#include "external/llvm/include/llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace llvm_ir { + +// A class for constructing a for-loop in LLVM IR. +class ForLoop { + public: + // Emit a for-loop at the current insert point of the given IRBuilder. + // + // start_index and end_index are the loop bounds (end_index is not inclusive). + // `step` is the increment of the loop index after each iteration. 
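+  //
+  // For example, a caller-side sketch (the builder `b` and the bound `n` are
+  // assumed to exist):
+  //
+  //   std::unique_ptr<ForLoop> loop = ForLoop::EmitForLoop(
+  //       "i", b.getInt64(0), b.getInt64(n), b.getInt64(1), &b);
+  //
+  // emits the moral equivalent of `for (int64 i = 0; i < n; ++i)`. The caller
+  // then positions the builder inside loop->GetBodyBasicBlock() and uses
+  // loop->GetIndVarValue() as the induction variable.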
+ // + // The current insert basic block of the builder is the preheader to the loop + // (see below for definition of basic block names). All instructions (if any) + // at or after the insert point in the insert basic block are moved to a newly + // created exit basic block. Instructions before the insert point remain in + // the insert BB: + // + // /--------------\ /----------------\ + // | insert BB | | insert BB | + // | ... | | (preheader BB) | + // | %foo = ... | | ... | + // insert point ->| %bar = ... | ===> | %foo = ... | + // | ... | \----------------/ + // \--------------/ | + // V + // [[ LOOP BBs ]] + // | + // V + // /--------------\ + // | exit BB | + // | %bar = ... | + // | ... | + // \--------------/ + // + // `suffix` is a string used to disambiguate variable and basic block names + // emitted in LLVM IR. This string is appended to the name of the induction + // variable value and each basic block created for the loop. The builder + // insert point is set to the end of the exit block after the function + // returns. + static std::unique_ptr EmitForLoop(tensorflow::StringPiece suffix, + llvm::Value* start_index, + llvm::Value* end_index, + llvm::Value* step, + llvm::IRBuilder<>* ir_builder); + + // The names of the blocks follow LLVM's conventions. Control flow amongst the + // blocks for the example C code looks like: + // + // for (int i = 0; i < n; ++i) { + // do_stuff(i); + // } + // + // /--------------\ + // | preheader BB | + // | i = 0 | + // \--------------/ + // | + // V + // /-------------\ + // | header BB |<-+ + // | if i < n: | | + // | goto body | | + // | else: | | + // | goto exit | | + // \-------------/ | + // | | | + // +--------+ | | + // | V | + // | /-------------\ | + // | | body BB | | + // | | dostuff(i) |--+ + // | | ++i | + // | \-------------/ + // | + // | /-------------\ + // +->| exit BB | + // \-------------/ + // + // Caller-emitted code to execute within the loop should be placed within the + // "body" basic block. + // + // Return pointers to various blocks in the loop. + llvm::BasicBlock* GetPreheaderBasicBlock() const { return preheader_bb_; } + llvm::BasicBlock* GetHeaderBasicBlock() const { return header_bb_; } + llvm::BasicBlock* GetBodyBasicBlock() const { return body_bb_; } + llvm::BasicBlock* GetExitBasicBlock() const { return exit_bb_; } + + // Return the Value representing the induction variable in the body basic + // block of the loop. + llvm::Value* GetIndVarValue() const { return indvar_; } + + private: + ForLoop(tensorflow::StringPiece suffix, llvm::Value* start_index, + llvm::Value* end_index, llvm::Value* step); + + // Emit the loop at the insert point of the builder. + void Emit(llvm::IRBuilder<>* ir_builder); + + llvm::BasicBlock* CreateBasicBlockWithSuffix(tensorflow::StringPiece name, + llvm::IRBuilder<>* ir_builder); + + // Create a name for an LLVM construct appending the member suffix_ if it is + // set. + string GetNameWithSuffix(tensorflow::StringPiece name); + + string suffix_; + llvm::Value* start_index_; + llvm::Value* end_index_; + llvm::Value* step_; + + // To improve readability of the IR, we want the basic blocks to appear + // consecutively in the following order: preheader, header, body, loop, + // exit. The member insert_before_bb_ points to where the next basic block + // should be created to ensure this ordering. 
+ llvm::BasicBlock* insert_before_bb_; + + llvm::BasicBlock* preheader_bb_; + llvm::BasicBlock* header_bb_; + llvm::BasicBlock* body_bb_; + llvm::BasicBlock* exit_bb_; + llvm::Value* indvar_; + + TF_DISALLOW_COPY_AND_ASSIGN(ForLoop); +}; + +// A simple class for constructing nested for-loops. +class ForLoopNest { + public: + explicit ForLoopNest(llvm::IRBuilder<>* ir_builder) + : outer_loop_preheader_bb_(nullptr), + outer_loop_exit_bb_(nullptr), + inner_loop_body_bb_(nullptr), + ir_builder_(ir_builder) {} + + // Adds a loop to the nest. If no loop has been added yet then emit a loop at + // the current insert point of the given builder. If one or more loops have + // been added then emit loop inside the body of the last added loop. + std::unique_ptr AddLoop(tensorflow::StringPiece suffix, + llvm::Value* start_index, + llvm::Value* end_index); + + // A convenient wrapper of the other flavor of AddLoop. The given start and + // end index are constant. + std::unique_ptr AddLoop(int64 start_index, int64 end_index, + tensorflow::StringPiece suffix); + + // Add loops to iterate through the indices within the specified + // shape. The returned index collects the induction variables of the + // loops so that it will iterate through all coordinates within the + // specified shape. + // + // E.g. if you pass in a 2x3 shape, you will get back an index with + // two entries that are induction variables of the two loops that + // will be added. That index will iterate through the 6 coordinates + // within the shape. One possible order for that sequence would be: + // + // (0,0), (0,1), (0,2), (1,0), (1,1), (1,2) + IrArray::Index AddLoopsForShape(const Shape& shape, + tensorflow::StringPiece suffix); + + // Add a loop for each dimension in "dimensions". "suffix" is the + // name suffix of the indvar and basic blocks in this new loop nest. + // + // The return value is an index with the induction variables. The + // size equals the rank of shape and there is a null for each + // dimension that is not in "dimensions". + IrArray::Index AddLoopsForShapeOnDimensions( + const Shape& shape, tensorflow::gtl::ArraySlice dimensions, + tensorflow::StringPiece suffix); + + // Convenience methods which return particular basic blocks of the outermost + // or innermost loops. These methods return nullptr if no loops have been + // added yet. + llvm::BasicBlock* GetOuterLoopPreheaderBasicBlock() { + return outer_loop_preheader_bb_; + } + llvm::BasicBlock* GetOuterLoopExitBasicBlock() { return outer_loop_exit_bb_; } + llvm::BasicBlock* GetInnerLoopBodyBasicBlock() { return inner_loop_body_bb_; } + + private: + // The preheader and exit basic block of the outermost loop, or nullptr if no + // loop has been added yet. + llvm::BasicBlock* outer_loop_preheader_bb_; + llvm::BasicBlock* outer_loop_exit_bb_; + + // The body basic block of the most-recently added loop, or nullptr if no loop + // has been added yet. + llvm::BasicBlock* inner_loop_body_bb_; + + llvm::IRBuilder<>* ir_builder_; + + TF_DISALLOW_COPY_AND_ASSIGN(ForLoopNest); +}; + +} // namespace llvm_ir +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_LLVM_LOOP_H_ diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc new file mode 100644 index 0000000000..d7a231db61 --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -0,0 +1,471 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" + +#include +#include + +#include "external/llvm/include/llvm/IR/MDBuilder.h" +#include "external/llvm/include/llvm/IR/Operator.h" +#include "external/llvm/include/llvm/Target/TargetOptions.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/legacy_flags/llvm_backend_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/llvm_util_flags.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/casts.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace llvm_ir { + +string AsString(const std::string& str) { + return string(str.data(), str.length()); +} + +llvm::StringRef AsStringRef(tensorflow::StringPiece str) { + return llvm::StringRef(str.data(), str.size()); +} + +string DumpModuleToString(const llvm::Module& module) { + std::string buffer_string; + llvm::raw_string_ostream ostream(buffer_string); + module.print(ostream, nullptr); + ostream.flush(); + return AsString(buffer_string); +} + +llvm::Value* EmitCallToIntrinsic( + llvm::Intrinsic::ID intrinsic_id, + tensorflow::gtl::ArraySlice operands, + tensorflow::gtl::ArraySlice overloaded_types, + llvm::IRBuilder<>* ir_builder) { + std::vector types; + for (auto type : overloaded_types) { + types.push_back(type); + } + llvm::Module* module = ir_builder->GetInsertBlock()->getParent()->getParent(); + llvm::Function* intrinsic = + llvm::Intrinsic::getDeclaration(module, intrinsic_id, types); + std::vector operands_vec; + for (auto operand : operands) { + operands_vec.push_back(operand); + } + return ir_builder->CreateCall(intrinsic, operands_vec); +} + +llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, llvm::Value* index, + llvm::IRBuilder<>* ir_builder) { + llvm::Type* array_type = array->getType(); + CHECK(array_type->isPointerTy()); + llvm::PointerType* array_type_as_pointer = + llvm::cast(array_type); + VLOG(2) << "EmitBufferIndexingGEP with type=" + << llvm_ir::DumpToString(*array_type) + << " array=" << llvm_ir::DumpToString(*array) + << " index=" << llvm_ir::DumpToString(*index); + + return ir_builder->CreateInBoundsGEP( + array_type_as_pointer->getElementType(), array, + llvm::isa(array) + ? 
llvm::ArrayRef({ir_builder->getInt64(0), index}) + : index); +} + +llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, int64 index, + llvm::IRBuilder<>* ir_builder) { + return EmitBufferIndexingGEP(array, ir_builder->getInt64(index), ir_builder); +} + +llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, + llvm::IRBuilder<>* ir_builder) { + switch (element_type) { + case PRED: + case S8: + case U8: + return ir_builder->getInt8Ty(); + case S16: + case U16: + return ir_builder->getInt16Ty(); + case S32: + case U32: + return ir_builder->getInt32Ty(); + case S64: + case U64: + return ir_builder->getInt64Ty(); + case F32: + return ir_builder->getFloatTy(); + case F64: + return ir_builder->getDoubleTy(); + // A Tuple contains an array of pointers. Use i8*. + case TUPLE: + // An Opaque is like a void*, use i8*. + case OPAQUE: + return ir_builder->getInt8PtrTy(); + default: + LOG(FATAL) << "unsupported type " << element_type; + } +} + +llvm::Type* ShapeToIrType(const Shape& shape, llvm::IRBuilder<>* ir_builder) { + llvm::Type* result_type = + PrimitiveTypeToIrType(shape.element_type(), ir_builder); + if (ShapeUtil::IsTuple(shape)) { + // A tuple buffer is an array of pointers. + result_type = llvm::ArrayType::get(result_type, shape.tuple_shapes_size()); + } else { + for (int64 dimension : shape.layout().minor_to_major()) { + result_type = + llvm::ArrayType::get(result_type, shape.dimensions(dimension)); + } + } + return result_type; +} + +namespace { + +// Recursively construct a multidimensional LLVM constant which represents the +// given literal. The minor-to-major dimension ordering in the constant matches +// that of the literal. For example, given a [2 x 3 x 4] Literal (dimension 0 +// has size 4, dimension 1 has size 3, etc) of primitive type F32 with a +// minor_to_major value of [2, 1, 0] (column major), a LLVM constant of type +// [4 x [3 x [2 x float]] will be returned. +// +// multi_index is a multidimensional index into the array. dimension_index is an +// index into the minor_to_major field in the literal shape. This determines +// which dimension is iterated over in this level of the recursion. Dimensions +// are iterated from most major down to most minor (highest dimension_index +// value down to zero). +llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index, + std::vector* multi_index, + llvm::IRBuilder<>* ir_builder) { + const Shape& shape = literal.shape(); + llvm::Type* ir_element_type = + llvm_ir::PrimitiveTypeToIrType(shape.element_type(), ir_builder); + if (dimension_index == -1) { + // Base case of the recursion. Index into the data field of the protobuf + // with the multi index. 
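+  // For example, for an F32 literal with *multi_index == {1, 2}, this produces
+  // the llvm::ConstantFP holding the element at coordinate (1, 2).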
+ llvm::Constant* value; + switch (shape.element_type()) { + case PRED: + value = llvm::ConstantInt::get( + ir_element_type, LiteralUtil::Get(literal, *multi_index)); + break; + case U8: + value = llvm::ConstantInt::get( + ir_element_type, LiteralUtil::Get(literal, *multi_index)); + break; + case S32: + value = llvm::ConstantInt::get( + ir_element_type, LiteralUtil::Get(literal, *multi_index)); + break; + case U32: + value = llvm::ConstantInt::get( + ir_element_type, LiteralUtil::Get(literal, *multi_index)); + break; + case S64: + value = llvm::ConstantInt::get( + ir_element_type, LiteralUtil::Get(literal, *multi_index)); + break; + case U64: + value = llvm::ConstantInt::get( + ir_element_type, LiteralUtil::Get(literal, *multi_index)); + break; + case F32: + value = llvm::ConstantFP::get( + ir_element_type, LiteralUtil::Get(literal, *multi_index)); + break; + case F64: + value = llvm::ConstantFP::get( + ir_element_type, LiteralUtil::Get(literal, *multi_index)); + break; + default: + LOG(FATAL) << "unsupported type " << shape.element_type(); + } + return value; + } + + // The dimension index starts at the one less than the rank of the array and + // decrements with each recursive call. We want to iterate through the + // dimensions in major-to-minor order as we recurse so just index into + // minor_to_major to get the dimension number for this level of the recursion. + int64 dimension = shape.layout().minor_to_major(dimension_index); + + // Recursively call LiteralToConstant to construct subarrays for the + // more-minor dimensions. Gather the subarrays into a vector for bundling into + // a new (higher-dimensional) ConstantArray. + std::vector elements; + for (int64 i = 0; i < shape.dimensions(dimension); ++i) { + (*multi_index)[dimension] = i; + elements.push_back(LiteralToConstant(literal, dimension_index - 1, + multi_index, ir_builder)); + } + + llvm::Type* element_type; + if (elements.empty()) { + element_type = ir_element_type; + for (int i = 0; i < dimension_index; ++i) { + int64 index = shape.layout().minor_to_major(i); + element_type = + llvm::ArrayType::get(element_type, shape.dimensions(index)); + } + } else { + element_type = elements[0]->getType(); + } + llvm::ArrayType* aggregate_type = + llvm::ArrayType::get(element_type, shape.dimensions(dimension)); + return llvm::ConstantArray::get(aggregate_type, elements); +} + +} // namespace + +llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal, + llvm::IRBuilder<>* ir_builder) { + std::vector multi_index(ShapeUtil::Rank(literal.shape()), 0); + llvm::Constant* value = LiteralToConstant( + literal, /*dimension_index=*/ShapeUtil::Rank(literal.shape()) - 1, + &multi_index, ir_builder); + return value; +} + +llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type, + tensorflow::StringPiece name, + llvm::IRBuilder<>* ir_builder, + int alignment) { + return EmitAllocaAtFunctionEntryWithCount(type, nullptr, name, ir_builder, + alignment); +} + +llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount( + llvm::Type* type, llvm::Value* element_count, tensorflow::StringPiece name, + llvm::IRBuilder<>* ir_builder, int alignment) { + llvm::IRBuilder<>::InsertPoint insert_point = ir_builder->saveIP(); + llvm::Function* function = ir_builder->GetInsertBlock()->getParent(); + ir_builder->SetInsertPoint(&function->getEntryBlock(), + function->getEntryBlock().getFirstInsertionPt()); + llvm::AllocaInst* alloca = + ir_builder->CreateAlloca(type, element_count, AsStringRef(name)); + if (alignment != 0) { + alloca->setAlignment(alignment); 
+ } + ir_builder->restoreIP(insert_point); + return alloca; +} + +llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before, + tensorflow::StringPiece name, + llvm::IRBuilder<>* ir_builder) { + return llvm::BasicBlock::Create( + /*Context=*/ir_builder->getContext(), + /*Name=*/AsStringRef(name), + /*Parent=*/ir_builder->GetInsertBlock()->getParent(), + /*InsertBefore*/ insert_before); +} + +LlvmIfData EmitIfThenElse(llvm::Value* condition, tensorflow::StringPiece name, + llvm::IRBuilder<>* ir_builder, bool emit_else) { + llvm_ir::LlvmIfData if_data; + if_data.if_block = ir_builder->GetInsertBlock(); + if_data.true_block = CreateBasicBlock( + nullptr, tensorflow::strings::StrCat(name, "-true"), ir_builder); + if_data.false_block = + emit_else ? CreateBasicBlock(nullptr, + tensorflow::strings::StrCat(name, "-false"), + ir_builder) + : nullptr; + + // There is no reason this function cannot work without a + // terminator, that is just a different case that has not been + // implemented yet. It is a different case because splitBasicBlock + // requires a terminator. + CHECK_NE(nullptr, if_data.if_block->getTerminator()); + if_data.after_block = if_data.if_block->splitBasicBlock( + ir_builder->GetInsertPoint(), + AsStringRef(tensorflow::strings::StrCat(name, "-after"))); + + // splitBasicBlock inserts an unconditional terminator that we have + // to remove as we want a conditional branch there. + if_data.if_block->getTerminator()->eraseFromParent(); + + ir_builder->SetInsertPoint(if_data.if_block); + ir_builder->CreateCondBr( + condition, if_data.true_block, + emit_else ? if_data.false_block : if_data.after_block); + + ir_builder->SetInsertPoint(if_data.true_block); + ir_builder->CreateBr(if_data.after_block); + + if (emit_else) { + ir_builder->SetInsertPoint(if_data.false_block); + ir_builder->CreateBr(if_data.after_block); + } + + ir_builder->SetInsertPoint(if_data.after_block, + if_data.after_block->getFirstInsertionPt()); + + return if_data; +} + +llvm::Value* EmitComparison(llvm::CmpInst::Predicate predicate, + llvm::Value* lhs_value, llvm::Value* rhs_value, + llvm::IRBuilder<>* ir_builder) { + llvm::Value* comparison_result; + if (lhs_value->getType()->isIntegerTy()) { + comparison_result = ir_builder->CreateICmp(predicate, lhs_value, rhs_value); + } else { + comparison_result = ir_builder->CreateFCmp(predicate, lhs_value, rhs_value); + } + // comparison_result is i1, but the NVPTX codegen incorrectly lowers i1 + // arrays. So we extend it to i8 so that it's addressable. + return ir_builder->CreateZExt( + comparison_result, llvm_ir::PrimitiveTypeToIrType(PRED, ir_builder)); +} + +// Internal helper that is called from emitted code to log an int64 value with a +// tag. 
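+// The address of this function is burned into the IR emitted by EmitLogging
+// below, so it must stay resident for the lifetime of the generated program.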
+static void LogS64(const char* tag, int64 value) { + LOG(INFO) << tag << " (int64): " << value; +} + +void EmitLogging(const char* tag, llvm::Value* value, + llvm::IRBuilder<>* ir_builder) { + llvm::FunctionType* log_function_type = llvm::FunctionType::get( + ir_builder->getVoidTy(), + {ir_builder->getInt64Ty(), ir_builder->getInt64Ty()}, /*isVarArg=*/false); + ir_builder->CreateCall( + log_function_type, + ir_builder->CreateIntToPtr( + ir_builder->getInt64(tensorflow::bit_cast(&LogS64)), + log_function_type->getPointerTo()), + {ir_builder->getInt64(tensorflow::bit_cast(tag)), value}); +} + +void SetTbaaForInstruction(llvm::Instruction* instruction, Shape shape, + bool is_pointer_to) { + legacy_flags::LlvmUtilFlags* flags = legacy_flags::GetLlvmUtilFlags(); + if (!flags->xla_emit_tbaa) { + return; + } + + llvm::MDBuilder metadata_builder(instruction->getContext()); + llvm::MDNode* root = metadata_builder.createTBAARoot("XLA TBAA"); + string type_name; + if (is_pointer_to) { + type_name += "pointer-to "; + } + // Scalars do not have layout which makes it permissible to omit an explicit + // layout. To make sure that equivalent scalar shapes have the same TBAA, + // remove the (meaningless) explicit layout if one is present. + if (ShapeUtil::Rank(shape) == 0) { + LayoutUtil::ClearLayout(&shape); + } else { + CHECK(shape.has_layout()); + } + type_name += shape.ShortDebugString(); + llvm::MDNode* tbaa_node = + metadata_builder.createTBAANode(llvm_ir::AsStringRef(type_name), root); + instruction->setMetadata(llvm::LLVMContext::MD_tbaa, + metadata_builder.createTBAAStructTagNode( + tbaa_node, tbaa_node, /*Offset=*/0)); +} + +void SetAlignmentMetadataForLoad(llvm::LoadInst* load, uint64_t alignment) { + llvm::LLVMContext& context = load->getContext(); + llvm::Type* int64_ty = llvm::Type::getInt64Ty(context); + llvm::Constant* alignment_constant = + llvm::ConstantInt::get(int64_ty, alignment); + llvm::MDBuilder metadata_builder(context); + auto* alignment_metadata = + metadata_builder.createConstant(alignment_constant); + load->setMetadata(llvm::LLVMContext::MD_align, + llvm::MDNode::get(context, alignment_metadata)); +} + +void SetDereferenceableMetadataForLoad(llvm::LoadInst* load, + uint64_t dereferenceable_bytes) { + llvm::LLVMContext& context = load->getContext(); + llvm::Type* int64_ty = llvm::Type::getInt64Ty(context); + llvm::Constant* dereferenceable_bytes_constant = + llvm::ConstantInt::get(int64_ty, dereferenceable_bytes); + llvm::MDBuilder metadata_builder(context); + auto* dereferenceable_bytes_metadata = + metadata_builder.createConstant(dereferenceable_bytes_constant); + load->setMetadata(llvm::LLVMContext::MD_dereferenceable, + llvm::MDNode::get(context, dereferenceable_bytes_metadata)); +} + +llvm::Instruction* AddRangeMetadata(int64 lower, int64 upper, + llvm::Instruction* inst) { + llvm::LLVMContext& context = inst->getParent()->getContext(); + llvm::IntegerType* i32 = llvm::Type::getInt32Ty(context); + inst->setMetadata( + llvm::LLVMContext::MD_range, + llvm::MDNode::get( + context, + {llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(i32, lower)), + llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(i32, upper))})); + return inst; +} + +string SanitizeIrName(string function_name) { + // Replace some characters that cannot occur in LLVM names with '_' + std::replace(function_name.begin(), function_name.end(), '.', '_'); + std::replace(function_name.begin(), function_name.end(), '%', '_'); + std::replace(function_name.begin(), function_name.end(), '-', '_'); + return 
function_name; +} + +void SetToFirstInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder) { + builder->SetInsertPoint(blk, blk->getFirstInsertionPt()); +} + +llvm::Value* CreateRor(llvm::Value* rotand, llvm::Value* rotor, + llvm::IRBuilder<>* builder) { + auto size = rotand->getType()->getPrimitiveSizeInBits(); + auto size_value = builder->getIntN(size, size); + auto mod = [=](llvm::Value* x) { return builder->CreateURem(x, size_value); }; + return builder->CreateOr( + builder->CreateShl(rotand, mod(builder->CreateSub(size_value, rotor))), + builder->CreateLShr(rotand, mod(rotor))); +} + +int64 ByteSizeOf(const Shape& shape, const llvm::DataLayout& data_layout) { + unsigned pointer_size = data_layout.getPointerSize(); + return ShapeUtil::ByteSizeOf(shape, pointer_size); +} + +void SetFastMathFlags(llvm::FastMathFlags* fast_math_flags) { + auto* flags = legacy_flags::GetLlvmBackendFlags(); + if (flags->xla_precision_losing_optimizations) { + fast_math_flags->setAllowReciprocal(); + } + if (flags->xla_fast_math) { + fast_math_flags->setUnsafeAlgebra(); + } +} + +void SetTargetOptions(llvm::TargetOptions* options) { + auto* flags = legacy_flags::GetLlvmBackendFlags(); + options->LessPreciseFPMADOption = options->UnsafeFPMath = + flags->xla_fast_math || flags->xla_precision_losing_optimizations; + options->NoInfsFPMath = options->NoNaNsFPMath = flags->xla_fast_math; +} + +} // namespace llvm_ir +} // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h new file mode 100644 index 0000000000..56c26b3800 --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h @@ -0,0 +1,228 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_LLVM_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_LLVM_UTIL_H_ + +#include +#include +#include + +#include "external/llvm/include/llvm/ADT/StringRef.h" +#include "external/llvm/include/llvm/IR/BasicBlock.h" +#include "external/llvm/include/llvm/IR/IRBuilder.h" +#include "external/llvm/include/llvm/IR/Instructions.h" +#include "external/llvm/include/llvm/IR/Module.h" +#include "external/llvm/include/llvm/IR/Value.h" +#include "external/llvm/include/llvm/Support/raw_ostream.h" +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/types.h" + +namespace llvm { +class FastMathFlags; +class TargetOptions; +}; + +namespace xla { +namespace llvm_ir { + +// Convert a std::string (used by LLVM's interfaces) to string. +string AsString(const std::string& str); + +// Convert a tensorflow::StringPiece to a llvm::StringRef. 
Note: both
+// tensorflow::StringPiece and llvm::StringRef are non-owning pointers into a
+// string in memory. This method is used to feed strings to LLVM
+// & Clang APIs that expect llvm::StringRef.
+llvm::StringRef AsStringRef(tensorflow::StringPiece str);
+
+template <typename T>
+llvm::ArrayRef<T> AsArrayRef(const std::vector<T>& vec) {
+  return llvm::ArrayRef<T>(vec.data(), vec.size());
+}
+
+template <typename T>
+llvm::ArrayRef<T> AsArrayRef(const tensorflow::gtl::ArraySlice<T>& slice) {
+  return llvm::ArrayRef<T>(slice.data(), slice.size());
+}
+
+// Dump the given LLVM entity to a string. This works for Types and Values.
+template <typename T>
+string DumpToString(const T& entity) {
+  std::string buffer_string;
+  llvm::raw_string_ostream ostream(buffer_string);
+  entity.print(ostream);
+  ostream.flush();
+  return AsString(buffer_string);
+}
+
+// Dump the given LLVM module to a string. This requires a function distinct
+// from DumpToString because the signatures of the print() methods for Values
+// and Modules are slightly different.
+string DumpModuleToString(const llvm::Module& module);
+
+// Sanitizes the given name to be a valid LLVM IR value name.
+string SanitizeIrName(string name);
+
+// Emits a call to the specified intrinsic with the given operands. Overloaded
+// intrinsics (for example, "minnum") must include a type in overloaded_types
+// for each overloaded type. Typically, overloaded intrinsics have only a
+// single overloaded type.
+llvm::Value* EmitCallToIntrinsic(
+    llvm::Intrinsic::ID intrinsic_id,
+    tensorflow::gtl::ArraySlice<llvm::Value*> operands,
+    tensorflow::gtl::ArraySlice<llvm::Type*> overloaded_types,
+    llvm::IRBuilder<>* ir_builder);
+
+// Convenience methods for emitting a GEP instruction that indexes into a
+// buffer (1-dimensional array), equivalent to array[index]. The type is
+// automatically determined from the element type of the array. The int64
+// index overload wraps the index in an i64 llvm::Value.
+llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, llvm::Value* index,
+                                   llvm::IRBuilder<>* ir_builder);
+llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, int64 index,
+                                   llvm::IRBuilder<>* ir_builder);
+
+// Returns the LLVM type which represents the given XLA primitive type.
+llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type,
+                                  llvm::IRBuilder<>* ir_builder);
+
+// Returns the LLVM type which represents the given XLA shape. For example,
+// if "shape" is [5 x [10 x f32]], the function returns [5 x [10 x float]].
+llvm::Type* ShapeToIrType(const Shape& shape, llvm::IRBuilder<>* ir_builder);
+
+// Converts a given literal to an IR Constant. Literals have known constant
+// values at IR emission time.
+llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal,
+                                           llvm::IRBuilder<>* ir_builder);
+
+// Inserts an allocate of the requested type at the entry point of the
+// function that the builder is currently building. The insert point
+// of the builder is set to the same place after calling this function
+// as before.
+//
+// This can be useful to avoid e.g. executing an alloca every time
+// through a loop.
+llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type,
+                                            tensorflow::StringPiece name,
+                                            llvm::IRBuilder<>* ir_builder,
+                                            int alignment = 0);
+
+// As EmitAllocaAtFunctionEntry, but allocates element_count entries
+// instead of a single element.
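+//
+// Illustrative sketch of a call site (names here are hypothetical; "b" stands
+// for the caller's llvm::IRBuilder<>):
+//   llvm::AllocaInst* scratch = EmitAllocaAtFunctionEntryWithCount(
+//       b.getFloatTy(), b.getInt64(8), "scratch", &b, /*alignment=*/16);
+// This emits a single "alloca float, i64 8, align 16" in the entry block of
+// the current function, no matter where "b" is currently inserting.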
+llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(
+    llvm::Type* type, llvm::Value* element_count, tensorflow::StringPiece name,
+    llvm::IRBuilder<>* ir_builder, int alignment = 0);
+
+// Creates a basic block with the same context and function as for the
+// builder. Inserts at the end of the function if insert_before is
+// null.
+llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before,
+                                   tensorflow::StringPiece name,
+                                   llvm::IRBuilder<>* ir_builder);
+
+// Struct with data on a conditional branch in a diamond shape created
+// via EmitIfThenElse.
+struct LlvmIfData {
+  // The block that has the conditional branch.
+  llvm::BasicBlock* if_block;
+
+  // The block that is executed if the condition is true.
+  llvm::BasicBlock* true_block;
+
+  // The block that is executed if the condition is false.
+  llvm::BasicBlock* false_block;
+
+  // The block that follows after both the true_block and the
+  // false_block.
+  llvm::BasicBlock* after_block;
+};
+
+// Inserts a diamond-shaped if-then-else construct at the current
+// insertion point of the builder. This involves splitting the current
+// block into two blocks, at the insertion point, and introducing a
+// true-block and a false-block that connect the two split pieces. The
+// true-block is executed if the condition parameter evaluates to true
+// and otherwise the false-block is executed. If `emit_else` is false,
+// it jumps to the after-block rather than the false-block if the
+// condition is false, and the returned `false_block` is null.
+//
+// Currently the insertion point of the builder must be a well-formed
+// block with a terminator. If you need to use this for a
+// non-terminated block, just make the function able to do that too.
+LlvmIfData EmitIfThenElse(llvm::Value* condition, tensorflow::StringPiece name,
+                          llvm::IRBuilder<>* ir_builder, bool emit_else = true);
+
+// Emits a compare operation between "lhs" and "rhs" with the given predicate,
+// and then converts the result to i8 so that it is addressable.
+llvm::Value* EmitComparison(llvm::CmpInst::Predicate predicate,
+                            llvm::Value* lhs, llvm::Value* rhs,
+                            llvm::IRBuilder<>* ir_builder);
+
+// Emits a call that logs the given value with the given tag as a prefix.
+// The provided tag and value are passed to a runtime logging call that is
+// embedded in this translation unit when the emitted code is executed.
+//
+// This can be very useful for debugging generated programs in short order when
+// developing new generated routines.
+//
+// Precondition: value must be an int64.
+// Precondition: tag must be a stable pointer for the lifetime of the generated
+// program (the constant pointer is burned into the program).
+void EmitLogging(const char* tag, llvm::Value* value,
+                 llvm::IRBuilder<>* ir_builder);
+
+// Adds TBAA metadata to a load or store instruction using the given shape as
+// its type. The is_pointer_to parameter is used to indicate whether or not
+// this instruction loads or stores a pointer to an array.
+void SetTbaaForInstruction(llvm::Instruction* instruction, Shape shape,
+                           bool is_pointer_to);
+
+// Adds alignment metadata to a load instruction using the given alignment.
+// The alignment refers to the result of the load, not the load itself.
+void SetAlignmentMetadataForLoad(llvm::LoadInst* load, uint64_t alignment);
+
+// Adds dereferenceable metadata to a load instruction using the given
+// number of dereferenceable bytes.
+// Dereferenceable refers to the result of the load, not the load itself.
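+// For example, annotating a load of a tuple-element pointer with the byte size
+// of the buffer it points to lets LLVM assume that many bytes are readable
+// through the result, which enables speculation of dependent loads.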
+void SetDereferenceableMetadataForLoad(llvm::LoadInst* load, + uint64_t dereferenceable_bytes); + +// Tells LLVM `inst >= lower && inst < upper`. Returns `inst` for convenience. +llvm::Instruction* AddRangeMetadata(int64 lower, int64 upper, + llvm::Instruction* inst); + +void SetToFirstInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder); + +// Create a bitwise rotation of `rotand` by `rotor`. +llvm::Value* CreateRor(llvm::Value* rotand, llvm::Value* rotor, + llvm::IRBuilder<>* builder); + +// Returns the number of bytes within the shape. +int64 ByteSizeOf(const Shape& shape, const llvm::DataLayout& data_layout); + +// Set values in the given FastMathFlags struct according to our XLA flags. +void SetFastMathFlags(llvm::FastMathFlags* flags); + +// Set values in the given TargetOptions struct according to our XLA flags. +void SetTargetOptions(llvm::TargetOptions* options); + +} // namespace llvm_ir +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_LLVM_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc new file mode 100644 index 0000000000..9a128b2aa6 --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc @@ -0,0 +1,103 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" + +#include +#include + +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace llvm_ir { + +LoopEmitter::LoopEmitter(const BodyEmitter& body_emitter, const Shape& shape, + llvm::IRBuilder<>* ir_builder) + : body_emitter_(body_emitter), shape_(shape), ir_builder_(ir_builder) {} + +LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator, + const IrArray& target_array, + llvm::IRBuilder<>* ir_builder) + : body_emitter_([=](const llvm_ir::IrArray::Index array_index) + -> ::tensorflow::Status { + // Convert target_element_generator to a BodyEmitter. + TF_ASSIGN_OR_RETURN(llvm::Value * target_element, + target_element_generator(array_index)); + target_array.EmitWriteArrayElement(array_index, target_element, + ir_builder); + return tensorflow::Status::OK(); + }), + shape_(target_array.GetShape()), + ir_builder_(ir_builder) {} + +IrArray::Index LoopEmitter::EmitIndexAndSetExitBasicBlock() { + CHECK(!ShapeUtil::IsTuple(shape_)); + if (ShapeUtil::IsScalar(shape_)) { + // No loop needed, so set exit_bb_ to nullptr. 
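+    // The caller's body emitter then runs exactly once, with an empty index,
+    // at the current insertion point (see EmitLoop below).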
+ exit_bb_ = nullptr; + return IrArray::Index(); + } + + // Create loop nest with one for-loop for each dimension of the target shape. + // Loops are added from outermost to innermost order with the ForLoopNest + // class so emit loops in order from most-major dimension down to most-minor + // dimension (of the target shape). + ForLoopNest loop_nest(ir_builder_); + IrArray::Index array_index(shape_.dimensions_size()); + for (int i = shape_.layout().minor_to_major_size() - 1; i >= 0; --i) { + int64 dimension = shape_.layout().minor_to_major(i); + std::unique_ptr loop = loop_nest.AddLoop( + /*start_index=*/0, + /*end_index=*/shape_.dimensions(dimension), + /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension)); + array_index[dimension] = loop->GetIndVarValue(); + } + + // Set IR builder insertion point to the loop body basic block of the + // innermost loop. + llvm::BasicBlock* innermost_body_bb = loop_nest.GetInnerLoopBodyBasicBlock(); + ir_builder_->SetInsertPoint(innermost_body_bb, + innermost_body_bb->getFirstInsertionPt()); + + // Set exit_bb_ to the exit block of the loop nest. + exit_bb_ = loop_nest.GetOuterLoopExitBasicBlock(); + CHECK_NOTNULL(exit_bb_); + + return array_index; +} + +tensorflow::Status LoopEmitter::EmitLoop() { + IrArray::Index array_index = EmitIndexAndSetExitBasicBlock(); + TF_RETURN_IF_ERROR(body_emitter_(array_index)); + + // Set the insertion point of ir_builder_ to the loop exit, so that + // code emitted for later instructions will be correctly placed. + if (exit_bb_ != nullptr) { + ir_builder_->SetInsertPoint(exit_bb_); + } + return tensorflow::Status::OK(); +} + +} // namespace llvm_ir +} // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h new file mode 100644 index 0000000000..08171e9e9d --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h @@ -0,0 +1,79 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_LOOP_EMITTER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_LOOP_EMITTER_H_ + +#include + +#include "external/llvm/include/llvm/IR/BasicBlock.h" +#include "external/llvm/include/llvm/IR/IRBuilder.h" +#include "external/llvm/include/llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { +namespace llvm_ir { + +// A function type for emitting code that generates an element in the target +// array. The function gets a multi-dimensional index as its only input. This +// index specifies the target element for which a value needs to be computed. +// The function has to emit code to compute this value and return the resulting +// llvm::Value*. +using ElementGenerator = + std::function(const IrArray::Index& index)>; + +// Emits a loop for every element in the given shape. 
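+//
+// Usage sketch (illustrative only; "result" is a hypothetical IrArray bound to
+// the output buffer and "b" is the caller's llvm::IRBuilder<>):
+//   LoopEmitter emitter(
+//       [&](const IrArray::Index& index) -> StatusOr<llvm::Value*> {
+//         return llvm::ConstantFP::get(b.getFloatTy(), 0.0);  // fill with 0.0f
+//       },
+//       result, &b);
+//   TF_RETURN_IF_ERROR(emitter.EmitLoop());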
+class LoopEmitter { + public: + using BodyEmitter = + std::function; + + LoopEmitter(const BodyEmitter& body_emitter, const Shape& shape, + llvm::IRBuilder<>* ir_builder); + // Constructs a LoopEmitter from an element generator that generates each + // element of the given target array. + LoopEmitter(const ElementGenerator& target_element_generator, + const IrArray& target_array, llvm::IRBuilder<>* ir_builder); + LoopEmitter(const LoopEmitter&) = delete; + LoopEmitter& operator=(const LoopEmitter&) = delete; + virtual ~LoopEmitter() = default; + + // Emits a loop nest (with a yet-to-be-filled loop body) that iterates through + // every element in the given shape. Returns the multi-dimensional index that + // specifies the element. + virtual IrArray::Index EmitIndexAndSetExitBasicBlock(); + + // Emits a complete loop nest for every element in the given shape. + tensorflow::Status EmitLoop(); + + protected: + // An IR emitter that generates the loop body. + BodyEmitter body_emitter_; + + // The shape that the emitted loop iterates through. + Shape shape_; + + // Points to the exit block of the emitted loop. If the given shape is + // scalar, no loops are emitted and exit_bb_ is nullptr in that case. + llvm::BasicBlock* exit_bb_; + + llvm::IRBuilder<>* ir_builder_; +}; + +} // namespace llvm_ir +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_LOOP_EMITTER_H_ diff --git a/tensorflow/compiler/xla/service/llvm_ir/ops.cc b/tensorflow/compiler/xla/service/llvm_ir/ops.cc new file mode 100644 index 0000000000..e01d25d250 --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/ops.cc @@ -0,0 +1,100 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/llvm_ir/ops.h" + +#include +#include +#include + +#include "external/llvm/include/llvm/IR/Instructions.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { +namespace llvm_ir { + +void EmitTupleSelect(IrArray select, IrArray pred, llvm::Value* on_true, + llvm::Value* on_false, llvm::IRBuilder<>* ir_builder) { + CHECK(ShapeUtil::IsScalar(pred.GetShape())); + + llvm::LoadInst* pred_value = + ir_builder->CreateLoad(pred.GetBasePointer(), "load_predicate_value"); + llvm::Value* pred_cond = ir_builder->CreateICmpNE( + pred_value, + llvm::ConstantInt::get(PrimitiveTypeToIrType(PRED, ir_builder), 0), + "boolean_predicate"); + + VLOG(2) << "HandleSelect for tuple:"; + VLOG(2) << " pred_value: " << DumpToString(*pred_value); + VLOG(2) << " pred_cond: " << DumpToString(*pred_cond); + + for (int i = 0; i < ShapeUtil::TupleElementCount(select.GetShape()); ++i) { + std::vector element_index = {ir_builder->getInt64(0), + ir_builder->getInt64(i)}; + llvm::Value* on_true_element_address = + ir_builder->CreateInBoundsGEP(on_true, element_index); + llvm::Value* on_true_element = ir_builder->CreateLoad( + on_true_element_address, + tensorflow::strings::Printf("on_true_element_%d", i).c_str()); + llvm::Value* on_false_element_address = + ir_builder->CreateInBoundsGEP(on_false, element_index); + llvm::Value* on_false_element = ir_builder->CreateLoad( + on_false_element_address, + tensorflow::strings::Printf("on_false_element_%d", i).c_str()); + + llvm::Value* output_element_address = + ir_builder->CreateInBoundsGEP(select.GetBasePointer(), element_index); + ir_builder->CreateStore( + ir_builder->CreateSelect( + pred_cond, on_true_element, on_false_element, + tensorflow::strings::Printf("select_output_element_%d", i).c_str()), + output_element_address); + } +} + +void EmitTuple(IrArray tuple, + tensorflow::gtl::ArraySlice operands, + llvm::IRBuilder<>* ir_builder) { + for (size_t i = 0; i < operands.size(); ++i) { + ir_builder->CreateStore( + ir_builder->CreatePointerCast(operands[i], + PrimitiveTypeToIrType(TUPLE, ir_builder)), + ir_builder->CreateInBoundsGEP( + tuple.GetBasePointer(), + {ir_builder->getInt64(0), ir_builder->getInt64(i)})); + } +} + +llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index, + int alignment, llvm::Value* operand, + llvm::IRBuilder<>* ir_builder) { + llvm::Value* element_ptr = ir_builder->CreateInBoundsGEP( + operand, {ir_builder->getInt64(0), ir_builder->getInt64(index)}); + llvm::LoadInst* src_buffer = ir_builder->CreateLoad(element_ptr); + SetTbaaForInstruction(src_buffer, target_shape, /*is_pointer_to=*/true); + SetAlignmentMetadataForLoad(src_buffer, alignment); + llvm::Type* element_type = ShapeToIrType(target_shape, ir_builder); + llvm::Value* ret_val = + ir_builder->CreateBitCast(src_buffer, element_type->getPointerTo()); + return ret_val; +} + +} // namespace llvm_ir +} // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/ops.h b/tensorflow/compiler/xla/service/llvm_ir/ops.h new file mode 100644 index 0000000000..af4063c340 --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/ops.h @@ -0,0 +1,79 @@ +/* Copyright 2017 The TensorFlow 
Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_ + +#include "external/llvm/include/llvm/IR/IRBuilder.h" +#include "external/llvm/include/llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace llvm_ir { + +// Selection among tuples is special in how it's lowered, because a tuple is not +// an HLO array. +// +// tuple_on_true tuple_on_false +// | | +// V V +// ------------------------ ------------------------ +// | address of element 0 | | address of element 0 | +// |----------------------| |----------------------| +// | address of element 1 | | address of element 1 | +// |----------------------| |----------------------| +// | address of element 2 | | address of element 2 | +// ------------------------ ------------------------ +// \ / +// \ / +// ---------- +// pred ---------> | select | +// ---------- +// | +// V +// output ----> ------------------------ +// | address of element 0 | +// |----------------------| +// | address of element 1 | +// |----------------------| +// | address of element 2 | +// ------------------------ +// +// Only the addresses are copied to the output. For each element, we emit a copy +// of the address from the corresponding element in either +// tuple_on_true or tuple_on_false: +// output[i] = pred ? tuple_on_true[i] : tuple_on_false[i] +void EmitTupleSelect(IrArray select, IrArray pred, llvm::Value* on_true, + llvm::Value* on_false, llvm::IRBuilder<>* ir_builder); + +// A tuple is an array of pointers, one for each operand. Each pointer points to +// the output buffer of its corresponding operand. +void EmitTuple(IrArray tuple, + tensorflow::gtl::ArraySlice operands, + llvm::IRBuilder<>* ir_builder); + +// A tuple is an array of pointers, one for each operand. Each pointer points to +// the output buffer of its corresponding operand. A GetTupleElement instruction +// forwards the pointer to underlying tuple element buffer at the given index. +// Returns an llvm value representing a pointer to the tuple element buffer. +llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index, + int alignment, llvm::Value* operand, + llvm::IRBuilder<>* ir_builder); +} // namespace llvm_ir +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_ diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc new file mode 100644 index 0000000000..38465e37e7 --- /dev/null +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -0,0 +1,543 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/local_service.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/backend.h" +#include "tensorflow/compiler/xla/service/computation_layout.h" +#include "tensorflow/compiler/xla/service/computation_tracker.h" +#include "tensorflow/compiler/xla/service/executable.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_execution_profile.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/platform_util.h" +#include "tensorflow/compiler/xla/service/user_computation.h" +#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" +#include "tensorflow/compiler/xla/shape_layout.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace se = ::perftools::gputools; + +namespace xla { + +LocalExecuteOptions& LocalExecuteOptions::set_platform( + perftools::gputools::Platform* platform) { + platform_ = platform; + return *this; +} + +perftools::gputools::Platform* LocalExecuteOptions::platform() const { + return platform_; +} + +LocalExecuteOptions& LocalExecuteOptions::set_device_ordinal( + int device_ordinal) { + device_ordinal_ = device_ordinal; + return *this; +} + +int LocalExecuteOptions::device_ordinal() const { return device_ordinal_; } + +LocalExecuteOptions& LocalExecuteOptions::set_allocator( + DeviceMemoryAllocator* allocator) { + allocator_ = allocator; + return *this; +} + +DeviceMemoryAllocator* LocalExecuteOptions::allocator() const { + return allocator_; +} + +LocalExecuteOptions& LocalExecuteOptions::set_stream( + perftools::gputools::Stream* stream) { + stream_ = stream; + return *this; +} + +perftools::gputools::Stream* LocalExecuteOptions::stream() const { + return stream_; +} + +LocalExecuteOptions& LocalExecuteOptions::set_execution_profile( + ExecutionProfile* profile) { + profile_ = profile; + return *this; +} + +ExecutionProfile* LocalExecuteOptions::execution_profile() const { + return profile_; +} + +LocalExecuteOptions& LocalExecuteOptions::set_result_layout( + const Shape& shape_with_layout) { + has_result_shape_with_layout_ = true; + result_shape_with_layout_ = shape_with_layout; + return *this; +} + +const Shape* LocalExecuteOptions::result_layout() const { + return has_result_shape_with_layout_ ? 
&result_shape_with_layout_ : nullptr; +} + +/* static */ StatusOr> LocalService::NewService( + perftools::gputools::Platform* platform) { + ServiceOptions default_options; + default_options.set_platform(platform); + return NewService(default_options); +} + +/* static */ StatusOr> LocalService::NewService( + const ServiceOptions& options) { + perftools::gputools::Platform* platform = options.platform(); + if (platform == nullptr) { + TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform()); + } + + TF_ASSIGN_OR_RETURN( + std::unique_ptr backend, + Backend::CreateBackend(platform, options.number_of_replicas())); + + TF_ASSIGN_OR_RETURN(std::unique_ptr compute_constant_backend, + CreateComputeConstantBackend()); + std::unique_ptr service(new LocalService( + std::move(backend), std::move(compute_constant_backend))); + return std::move(service); +} + +LocalService::LocalService(std::unique_ptr execute_backend, + std::unique_ptr compute_constant_backend) + : Service(std::move(execute_backend), std::move(compute_constant_backend)) { + runs_in_client_process_ = true; +} + +tensorflow::Status LocalService::ResolveArguments( + const tensorflow::gtl::ArraySlice arguments, + int device_ordinal, + std::vector* argument_ptrs) { + TF_ASSIGN_OR_RETURN(std::vector arg_allocations, + ResolveAndValidateArguments( + arguments, execute_backend_.get(), device_ordinal)); + argument_ptrs->resize(arg_allocations.size()); + for (int i = 0; i < arguments.size(); ++i) { + const Allocation& allocation = *arg_allocations[i]; + (*argument_ptrs)[i] = allocation.device_memory(); + } + return tensorflow::Status::OK(); +} + +namespace { +// Returns the space required to allocate a shape. If +// allocate_space_for_deep_copy the space includes all sub-buffers of +// a tuple. +int64 RequiredSpace(const Shape& shape, bool allocate_space_for_deep_copy, + TransferManager* transfer_manager) { + int64 size = 0; + // TODO(b/33492279) remove once no devices represent result tuples as + // contiguous buffers. 
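+  // For a tuple shape this sums the byte-size requirements of all subshapes:
+  // the top-level pointer table plus every leaf buffer.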
+ if (allocate_space_for_deep_copy) { + TF_CHECK_OK(ShapeUtil::ForEachSubshape( + shape, [&size, transfer_manager](const Shape& subshape, + const ShapeIndex& /*index*/) { + size += transfer_manager->GetByteSizeRequirement(subshape); + return tensorflow::Status::OK(); + })); + } + return size; +} +} // namespace + +StatusOr LocalService::AllocateBufferOnDevice( + const Shape& shape, int device_ordinal, bool allocate_space_for_deep_copy) { + int64 allocation_size = RequiredSpace(shape, allocate_space_for_deep_copy, + execute_backend_->transfer_manager()); + + TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase allocation, + execute_backend_->memory_allocator()->Allocate( + device_ordinal, allocation_size)); + + return allocation_tracker_.Register( + execute_backend_.get(), device_ordinal, allocation, shape, + tensorflow::strings::StrCat("AllocateBufferOnDevice of size ", + allocation_size)); +} + +StatusOr> LocalService::ExecuteLocally( + const ComputationHandle& computation, + tensorflow::gtl::ArraySlice arguments, + const LocalExecuteOptions& options) { + return ExecuteLocallyInternal(computation, arguments, options, + /*preallocated_result_buffer=*/nullptr); +} + +tensorflow::Status LocalService::ExecuteLocally( + const ComputationHandle& computation, + tensorflow::gtl::ArraySlice arguments, + const LocalExecuteOptions& options, ShapedBuffer* result_buffer) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr null_buffer, + ExecuteLocallyInternal(computation, arguments, options, result_buffer)); + // Because the result is written into result_buffer, a null ShapedBuffer + // pointer should have been returned. + CHECK_EQ(nullptr, null_buffer.get()); + return tensorflow::Status::OK(); +} + +StatusOr> +LocalService::CompileAheadOfTime( + const ComputationHandle& computation, + const tensorflow::gtl::ArraySlice argument_layouts, + const Shape& result_layout, const AotCompilationOptions& options) { + TF_ASSIGN_OR_RETURN(UserComputation * user_computation, + computation_tracker_.Resolve(computation)); + VersionedComputationHandle versioned_handle = + user_computation->GetVersionedHandle(); + + TF_ASSIGN_OR_RETURN( + std::unique_ptr hlo_module, + computation_tracker_.BuildHloModule(versioned_handle, + /*include_unused_parameters=*/true)); + + TF_ASSIGN_OR_RETURN( + std::shared_ptr program_shape, + user_computation->ComputeProgramShape(versioned_handle.version)); + + auto module_config = MakeUnique(*program_shape); + auto* computation_layout = module_config->mutable_entry_computation_layout(); + for (int i = 0; i < argument_layouts.size(); ++i) { + const Shape& argument_layout = *argument_layouts[i]; + if (ShapeUtil::IsTuple(argument_layout)) { + return Unimplemented("tuple arguments not supported yet"); + } + TF_RETURN_IF_ERROR( + computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape( + argument_layout)); + } + TF_RETURN_IF_ERROR( + computation_layout->mutable_result_layout()->CopyLayoutFromShape( + result_layout)); + + return execute_backend_->compiler() + ->CompileAheadOfTime(std::move(hlo_module), std::move(module_config), + MakeHloDumper(), options) + .ConsumeValueOrDie(); +} + +tensorflow::Status LocalService::ValidateExecuteOptions( + const ProgramShape& program_shape, + tensorflow::gtl::ArraySlice argument_layouts, + const LocalExecuteOptions& options, + const ShapedBuffer* preallocated_result_buffer) { + if (argument_layouts.size() != program_shape.parameters_size()) { + return InvalidArgument( + "invalid number of arguments for computation: expected %d, got %zu", + 
program_shape.parameters_size(), argument_layouts.size()); + } + + if (options.stream()) { + if (!options.stream()->ok()) { + return InvalidArgument("stream is uninitialized or in an error state"); + } + + // Check stream matches service platform. + const se::Platform* stream_platform = + options.stream()->parent()->platform(); + if (stream_platform != execute_backend_->platform()) { + return InvalidArgument( + "stream is for platform %s, but service targets platform %s", + stream_platform->Name().c_str(), + execute_backend_->platform()->Name().c_str()); + } + + // Cannot specify platform or device_ordinal with a stream. The stream + // determines these values. + if (options.device_ordinal() >= 0) { + return InvalidArgument( + "cannot set both device ordinal and stream options in " + "LocalExecuteOptions; the stream determines the device ordinal"); + } + if (options.platform()) { + return InvalidArgument( + "cannot set both platform and stream options in " + "LocalExecuteOptions; the stream determines the platform"); + } + } + if (options.platform() && + options.platform() != execute_backend_->platform()) { + return InvalidArgument( + "service platform (%s) does not match platform set in " + "LocalExecuteOptions (%s)", + execute_backend_->platform()->Name().c_str(), + options.platform()->Name().c_str()); + } + + // TODO(cwhipkey): validate the thread pool provided? + + if (!options.allocator()) { + return InvalidArgument("an allocator must be provided to ExecuteLocally"); + } + + if (options.allocator()->platform() != execute_backend_->platform()) { + return InvalidArgument( + "allocator platform (%s) does not match service platform (%s)", + options.allocator()->platform()->Name().c_str(), + execute_backend_->platform()->Name().c_str()); + } + + if (preallocated_result_buffer != nullptr) { + if (options.result_layout()) { + return InvalidArgument( + "cannot set both result ShapedBuffer and result layout; the result " + "ShapedBuffer determines the result layout"); + } + if (!ShapeUtil::Compatible(preallocated_result_buffer->shape(), + program_shape.result())) { + return InvalidArgument( + "result ShapedBuffer of shape %s not compatible with computation " + "result shape %s", + ShapeUtil::HumanString(preallocated_result_buffer->shape()).c_str(), + ShapeUtil::HumanString(program_shape.result()).c_str()); + } + } + if (options.result_layout()) { + TF_RETURN_IF_ERROR(ValidateResultShapeWithLayout(*options.result_layout(), + program_shape.result())); + } + + // Check that all argument layouts are valid and the right shape. 
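+  // ShapeUtil::Compatible checks element types and dimensions but not layout,
+  // so arguments with any layout are accepted here.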
+ for (int i = 0; i < argument_layouts.size(); ++i) { + const Shape& argument_shape = *argument_layouts[i]; + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(argument_shape)); + if (!ShapeUtil::Compatible(argument_shape, program_shape.parameters(i))) { + return InvalidArgument( + "invalid argument shape for argument %d, expected %s, got %s", i, + ShapeUtil::HumanString(program_shape.parameters(i)).c_str(), + ShapeUtil::HumanString(argument_shape).c_str()); + } + } + + return tensorflow::Status::OK(); +} + +StatusOr> LocalService::ExecuteLocallyInternal( + const ComputationHandle& computation, + tensorflow::gtl::ArraySlice arguments, + const LocalExecuteOptions& options, + ShapedBuffer* preallocated_result_buffer) { + TF_ASSIGN_OR_RETURN(UserComputation * user_computation, + computation_tracker_.Resolve(computation)); + VersionedComputationHandle versioned_handle = + user_computation->GetVersionedHandle(); + + TF_ASSIGN_OR_RETURN( + std::shared_ptr program_shape, + user_computation->ComputeProgramShape(versioned_handle.version)); + + // Determine device ordinal the computation will run on. + int device_ordinal; + if (options.device_ordinal() >= 0) { + device_ordinal = options.device_ordinal(); + } else if (options.stream()) { + device_ordinal = options.stream()->parent()->device_ordinal(); + } else { + device_ordinal = execute_backend_->default_device_ordinal(); + } + + // Check that all arguments are on the right platform and device ordinal. + std::vector argument_layouts(arguments.size()); + for (int i = 0; i < arguments.size(); ++i) { + auto argument = arguments[i]; + if (argument->platform() != execute_backend_->platform() || + argument->device_ordinal() != device_ordinal) { + return InvalidArgument( + "computation to run on device %s but argument %d is on " + "device %s:%d", + execute_backend_->device_name(device_ordinal).c_str(), i, + argument->platform()->Name().c_str(), argument->device_ordinal()); + } + argument_layouts[i] = &argument->shape(); + } + + TF_RETURN_IF_ERROR(ValidateExecuteOptions( + *program_shape, argument_layouts, options, preallocated_result_buffer)); + + // Construct computation layout from the argument layouts. 
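+  // Parameter layouts are taken from the argument buffers themselves; the
+  // result layout comes from the options, the preallocated result buffer, or
+  // the default layout, in that order of preference.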
+ auto module_config = MakeUnique(*program_shape); + module_config->set_has_hybrid_result(true); + module_config->set_replica_count(execute_backend_->Replicas().size()); + std::vector argument_buffers; + auto* computation_layout = module_config->mutable_entry_computation_layout(); + for (int i = 0; i < arguments.size(); ++i) { + const ShapedBuffer* argument = arguments[i]; + if (ShapeUtil::IsTuple(argument->shape())) { + return Unimplemented("tuple arguments not supported yet"); + } + argument_buffers.push_back(argument->buffer(/*index=*/{})); + TF_RETURN_IF_ERROR( + computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape( + argument->shape())); + } + if (options.result_layout()) { + TF_RETURN_IF_ERROR( + computation_layout->mutable_result_layout()->CopyLayoutFromShape( + *options.result_layout())); + } else if (preallocated_result_buffer != nullptr) { + TF_RETURN_IF_ERROR( + computation_layout->mutable_result_layout()->CopyLayoutFromShape( + preallocated_result_buffer->shape())); + } else { + computation_layout->mutable_result_layout()->SetToDefaultLayout(); + } + + ExecutableRunOptions run_options; + run_options.set_allocator(options.allocator()); + run_options.set_inter_op_thread_pool( + execute_backend_->inter_op_thread_pool()); + run_options.set_intra_op_thread_pool( + execute_backend_->eigen_intra_op_thread_pool_device()); + + // "acquired_stream" owns the stream used for execution if no stream is given. + std::unique_ptr acquired_stream; + if (options.stream()) { + run_options.set_stream(options.stream()); + } else { + se::StreamExecutor* stream_executor; + if (options.device_ordinal() >= 0) { + TF_ASSIGN_OR_RETURN(stream_executor, execute_backend_->stream_executor( + options.device_ordinal())); + } else { + stream_executor = execute_backend_->default_stream_executor(); + } + TF_ASSIGN_OR_RETURN(acquired_stream, + execute_backend_->AcquireStream(stream_executor)); + run_options.set_stream(acquired_stream.get()); + } + auto stream_releaser = + ::tensorflow::gtl::MakeCleanup([this, &acquired_stream]() { + if (acquired_stream != nullptr) { + execute_backend_->ReleaseStream(std::move(acquired_stream)); + } + }); + + ExecutionProfile* profile = options.execution_profile(); + TF_ASSIGN_OR_RETURN( + std::shared_ptr executable, + BuildAndCacheExecutable(versioned_handle, std::move(module_config), + argument_buffers, execute_backend_.get(), + run_options.stream()->parent(), profile)); + + if (preallocated_result_buffer == nullptr) { + return Service::ExecuteOnStreamWrapper< + StatusOr>>( + executable.get(), &run_options, profile, + [&arguments](Executable* executable, + const ExecutableRunOptions* run_options, + HloExecutionProfile* hlo_execution_profile) { + return executable->ExecuteOnStream(run_options, arguments, + hlo_execution_profile); + }); + } else { + TF_RETURN_IF_ERROR(Service::ExecuteOnStreamWrapper( + executable.get(), &run_options, profile, + [&arguments, preallocated_result_buffer]( + Executable* executable, const ExecutableRunOptions* run_options, + HloExecutionProfile* hlo_execution_profile) { + return executable->ExecuteOnStream(run_options, arguments, + preallocated_result_buffer, + hlo_execution_profile); + })); + // To satisfy the return value type, Return a null ShapedBuffer pointer. 
+ return std::unique_ptr(); + } +} + +StatusOr> LocalService::CompileExecutable( + const ComputationHandle& computation, + const tensorflow::gtl::ArraySlice argument_layouts, + const Shape* result_layout, int device_ordinal, bool has_hybrid_result) { + TF_ASSIGN_OR_RETURN(UserComputation * user_computation, + computation_tracker_.Resolve(computation)); + VersionedComputationHandle versioned_handle = + user_computation->GetVersionedHandle(); + + TF_ASSIGN_OR_RETURN( + std::shared_ptr program_shape, + user_computation->ComputeProgramShape(versioned_handle.version)); + + // Validate incoming layouts. + if (argument_layouts.size() != program_shape->parameters_size()) { + return InvalidArgument( + "invalid number of arguments for computation: expected %d, got %zu", + program_shape->parameters_size(), argument_layouts.size()); + } + for (int i = 0; i < argument_layouts.size(); ++i) { + const Shape& argument_shape = *argument_layouts[i]; + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(argument_shape)); + if (!ShapeUtil::Compatible(argument_shape, program_shape->parameters(i))) { + return InvalidArgument( + "invalid argument shape for argument %d, expected %s, got %s", i, + ShapeUtil::HumanString(program_shape->parameters(i)).c_str(), + ShapeUtil::HumanString(argument_shape).c_str()); + } + } + if (result_layout != nullptr) { + TF_RETURN_IF_ERROR( + ValidateResultShapeWithLayout(*result_layout, program_shape->result())); + } + + // Construct computation layout from the argument layouts. + auto module_config = MakeUnique(*program_shape); + module_config->set_has_hybrid_result(has_hybrid_result); + module_config->set_replica_count(execute_backend_->Replicas().size()); + auto* computation_layout = module_config->mutable_entry_computation_layout(); + for (int i = 0; i < argument_layouts.size(); ++i) { + const Shape& shape = *argument_layouts[i]; + if (ShapeUtil::IsTuple(shape)) { + return Unimplemented("tuple arguments not supported yet"); + } + TF_RETURN_IF_ERROR( + computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape( + shape)); + } + if (result_layout != nullptr) { + TF_RETURN_IF_ERROR( + computation_layout->mutable_result_layout()->CopyLayoutFromShape( + *result_layout)); + } else { + computation_layout->mutable_result_layout()->SetToDefaultLayout(); + } + + TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, + execute_backend_->stream_executor(device_ordinal)); + + std::vector argument_buffers( + argument_layouts.size()); + return BuildExecutable(versioned_handle, std::move(module_config), + /*executable_for_compute_constant=*/false, + argument_buffers, execute_backend_.get(), executor); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h new file mode 100644 index 0000000000..3e160a0201 --- /dev/null +++ b/tensorflow/compiler/xla/service/local_service.h @@ -0,0 +1,185 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LOCAL_SERVICE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_LOCAL_SERVICE_H_ + +#include + +#include "tensorflow/compiler/xla/service/backend.h" +#include "tensorflow/compiler/xla/service/compiler.h" +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/executable.h" +#include "tensorflow/compiler/xla/service/service.h" +#include "tensorflow/compiler/xla/service/shaped_buffer.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { + +// Computation execution options which may be set by the client when executing +// locally (via LocalClient::ExecuteLocally). +class LocalExecuteOptions { + public: + // Specifies the allocator to use during execution. Execution will fail if no + // allocator is provided. + LocalExecuteOptions& set_allocator(DeviceMemoryAllocator* allocator); + DeviceMemoryAllocator* allocator() const; + + // If set, this is the platform to run the computation on. This must match + // the underlying platform of the service. A value of nullptr means the + // platform is not set. + // TODO(b/28616830): Support multiple platforms. + LocalExecuteOptions& set_platform(perftools::gputools::Platform* platform); + perftools::gputools::Platform* platform() const; + + // If set, this is the device to run the computation on. Valid device_ordinal + // values are: 0 to # of devices - 1. These values are identical to the + // device ordinal values used by StreamExecutor. A value of < 0 means the + // ordinal is not set. + LocalExecuteOptions& set_device_ordinal(int device_ordinal); + int device_ordinal() const; + + // If set, this is the stream to run the computation on. The platform of the + // stream must match the service's platform. The device ordinal + // option (if set) must match the stream's device. A value of nullptr means + // the stream is not set. + LocalExecuteOptions& set_stream(perftools::gputools::Stream* stream); + perftools::gputools::Stream* stream() const; + + // If set, collect profile information during execution and fill the given + // ExecutionProfile object with the profile data. A value of nullptr means + // the profile is not set. + LocalExecuteOptions& set_execution_profile(ExecutionProfile* profile); + ExecutionProfile* execution_profile() const; + + // If set, this specifies the layout of the result of the computation. If not + // set, the service will chose the layout of the result. A Shape is used to + // store the layout to accomodate tuple result shapes. A value of nullptr + // means the shape is not set. + LocalExecuteOptions& set_result_layout(const Shape& shape_with_layout); + const Shape* result_layout() const; + + private: + DeviceMemoryAllocator* allocator_ = nullptr; + perftools::gputools::Platform* platform_ = nullptr; + int device_ordinal_ = -1; + perftools::gputools::Stream* stream_ = nullptr; + ExecutionProfile* profile_ = nullptr; + + bool has_result_shape_with_layout_ = false; + Shape result_shape_with_layout_; +}; + +// Service implementation that extends the XLA Service to leverage running +// in the same process as the client. +class LocalService : public Service { + public: + // Factory for creating a LocalService. The parameter platform is the platform + // that the service should target. 
If platform is null then the default + // platform is used. + static StatusOr> NewService( + perftools::gputools::Platform* platform); + static StatusOr> NewService( + const ServiceOptions& options); + + // For an array of arguments, validate that each is placed on the + // specified device_ordinal, and return the DeviceMemoryBase + // corresponding to each argument. + tensorflow::Status ResolveArguments( + const tensorflow::gtl::ArraySlice arguments, + int device_ordinal, + std::vector* argument_ptrs); + + // Return a handle to a buffer large enough to hold shape, allocated + // on device_ordinal. If allocate_space_for_deep_copy, the buffer is + // large enough to hold all sub-buffers of a tuple shape, otherwise + // it is only as large as the top-level tuple pointer array. + StatusOr AllocateBufferOnDevice( + const Shape& shape, int device_ordinal, + bool allocate_space_for_deep_copy); + + // Execute the given computation with the given arguments and options with + // zero-copy data handling of arguments and result. + StatusOr> ExecuteLocally( + const ComputationHandle& computation, + tensorflow::gtl::ArraySlice arguments, + const LocalExecuteOptions& options); + + // Overload which writes the result into the given ShapedBuffer "result". + // Due to aliasing, not all buffers which comprise "result" may be utilized + // in the computation and thus be uninitialized. The |ShapedBuffer::buffer| + // or |ShapedBuffer::mutable_buffer| methods should be used to map an index to + // the initialized buffer. + // + // For example: + // Let 'result' be a ShapedBuffer holding a tuple with the same element, + // 'x', twice: (x, x). It is incorrect to assume that the second buffer + // which comprises 'result' is initialized. Instead, a mapping has been + // added to 'result' which can be used to recover the correct buffer. + // In this case, result->buffer({0}) should be used to extract the address of + // the first tuple element while result->buffer({1}) should be used for the + // second. + tensorflow::Status ExecuteLocally( + const ComputationHandle& computation, + tensorflow::gtl::ArraySlice arguments, + const LocalExecuteOptions& options, ShapedBuffer* result_buffer); + + // Compiles the computation for ahead-of-time execution. This is intended for + // use in static compilation. See |LocalClient::CompileAheadOfTime| for + // additional details. + StatusOr> CompileAheadOfTime( + const ComputationHandle& computation, + const tensorflow::gtl::ArraySlice argument_layouts, + const Shape& result_layout, const AotCompilationOptions& Options); + + // Builds an Executable with the given argument layouts and options. If + // result_layout is non-null, then the executable is compiled to produce a + // result of the given layout. + StatusOr> CompileExecutable( + const ComputationHandle& computation, + const tensorflow::gtl::ArraySlice argument_layouts, + const Shape* result_layout, int device_ordinal, bool has_hybrid_result); + + private: + explicit LocalService(std::unique_ptr backend, + std::unique_ptr compute_constant_backend); + LocalService(const LocalService&) = delete; + void operator=(const LocalService&) = delete; + + // Internal helper for executing a computation. If result_buffer is null then + // the result is returned as a ShapedBuffer. If result_buffer is non-null then + // the result is written into result_buffer and a null ShapedBuffer pointer is + // returned. 
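+  // Both public ExecuteLocally overloads above are intended to funnel into
+  // this helper; see the definitions in local_service.cc.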
+ StatusOr> ExecuteLocallyInternal( + const ComputationHandle& computation, + tensorflow::gtl::ArraySlice arguments, + const LocalExecuteOptions& options, + ShapedBuffer* preallocated_result_buffer); + + // Validates the given options and argument layouts and returns an appropriate + // error code. + tensorflow::Status ValidateExecuteOptions( + const ProgramShape& program_shape, + tensorflow::gtl::ArraySlice arguments, + const LocalExecuteOptions& options, + const ShapedBuffer* preallocated_result_buffer); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LOCAL_SERVICE_H_ diff --git a/tensorflow/compiler/xla/service/logical_buffer.cc b/tensorflow/compiler/xla/service/logical_buffer.cc new file mode 100644 index 0000000000..00e4b35d15 --- /dev/null +++ b/tensorflow/compiler/xla/service/logical_buffer.cc @@ -0,0 +1,39 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/logical_buffer.h" + +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace xla { + +string LogicalBuffer::ToString() const { + return tensorflow::strings::StrCat(instruction_->name(), "[", + tensorflow::str_util::Join(index_, ","), + "](#", id_, ")"); +} + +std::ostream& operator<<(std::ostream& out, const LogicalBuffer& buffer) { + out << buffer.ToString(); + return out; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/logical_buffer.h b/tensorflow/compiler/xla/service/logical_buffer.h new file mode 100644 index 0000000000..21af9dcf66 --- /dev/null +++ b/tensorflow/compiler/xla/service/logical_buffer.h @@ -0,0 +1,153 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LOGICAL_BUFFER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_LOGICAL_BUFFER_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +struct HashLogicalBuffer; + +// Class describing a contiguous sequence of elements (ie, C array) which form +// the components of Shaped values in XLA. XLA arrays are trivially a +// single LogicalBuffer. Tuple values are made up of more than one +// LogicalBuffer: a LogicalBuffer for the pointers to elements, and a +// LogicalBuffer for each child element. +// +// Every buffer is defined by a particular instruction and most instructions +// define only a single buffer. Instructions which define a single buffer +// include array-shaped instructions such as Add but also includes Tuple-shaped +// instructions such as Tuple. The Tuple instruction defines a single buffer +// which is a vector of pointers to the buffers containing the Tuple +// instruction's operands. Though the result of the Tuple instruction includes +// multiple buffers only the top-level buffer (the vector of pointers) is +// defined by the Tuple instruction. The buffers containing the tuple elements +// are defined by earlier instructions, usually the operands of the Tuple +// instruction. +// +// Instructions which construct both the tuple *and* the tuple elements define +// more than one buffer. This includes (at least) tuple-shaped Constant, +// Parameter, Infeed and While instructions. The tuple-shaped instructions do +// not assemble a tuple from existing buffers like the Tuple instruction does, +// but rather define the entire tuple. +// +// Some instructions, such as Bitcast, define no buffers. These instructions +// simply forward buffers from their operands. +// +// The LogicalBuffer object describes which HLO instruction defines a buffer and +// where within that instruction's output shape the buffer is defined. The +// location within the output shape is indicated by LogicalBuffer::index() which +// is defined identically to the index used in +// ShapeUtil::GetSubshape(). Examples: +// +// %add = Add(%foo, %bar) +// %tuple_constant = Constant({1, {42, 43}}) +// +// %add defines a single array-shaped buffer LogicalBuffer(%add, {}) which holds +// the array result of the add operation. The nested-tuple-shaped +// %tuple_constant defines 5 buffers described by the following LogicalBuffer +// objects: +// +// LogicalBuffer(%tuple_constant, {}) // "Top-level" buffer: vector of +// // pointers to LogicalBuffers at +// // indices {0} and {1} +// LogicalBuffer(%tuple_constant, {0}) // Holds value "1" +// LogicalBuffer(%tuple_constant, {1}) // Holds nested tuple: vector of +// // pointers to LogicalBuffers at +// // indices {1, 0} and {1, 1} +// LogicalBuffer(%tuple_constant, {1, 0}) // Holds value "42" +// LogicalBuffer(%tuple_constant, {1, 1}) // Holds value "43" +class LogicalBuffer { + public: + // Id is a unique identifier for the LogicalBuffer to facilitate efficient + // collections of LogicalBuffers with stable iteration order. 
+  // LogicalBuffers are typically created and accessed through
+  // TuplePointsToAnalysis, and points-to analysis assigns each LogicalBuffer a
+  // unique value.
+  using Id = int64;
+
+  LogicalBuffer(HloInstruction* instruction, const ShapeIndex& index, Id id)
+      : instruction_(instruction), index_(index), id_(id) {}
+
+  Id id() const { return id_; }
+
+  // Return the instruction that defines the buffer.
+  HloInstruction* instruction() const { return instruction_; }
+
+  // Return the index within the output of the instruction where the buffer is
+  // defined. The index is defined as in ShapeUtil::GetSubshape().
+  const ShapeIndex& index() const { return index_; }
+
+  // Return the shape of the buffer. This reference points into the shape field
+  // of the instruction defining the buffer. Therefore, the returned shape will
+  // contain the layout of the instruction, if any.
+  const Shape& shape() const {
+    return ShapeUtil::GetSubshape(instruction_->shape(), index_);
+  }
+
+  // Returns true if this buffer is the top-level output buffer of the defining
+  // HLO instruction. This is equivalent to index == {}.
+  bool IsTopLevel() const { return index_.empty(); }
+
+  // Whether this buffer contains a tuple.
+  bool IsTuple() const { return ShapeUtil::IsTuple(shape()); }
+
+  // operator< is required for std::set.
+  bool operator<(const LogicalBuffer& other) const { return id_ < other.id_; }
+
+  // Whether this buffer contains an array.
+  bool IsArray() const { return ShapeUtil::IsArray(shape()); }
+
+  string ToString() const;
+
+ private:
+  friend struct HashLogicalBuffer;
+  HloInstruction* instruction_;
+  ShapeIndex index_;
+  Id id_;
+
+  // Similar to HLO constructs (HloInstruction, etc.), pointers are used for
+  // comparison for equality, so disable all copying.
+  TF_DISALLOW_COPY_AND_ASSIGN(LogicalBuffer);
+};
+
+struct HashLogicalBuffer {
+  size_t operator()(const LogicalBuffer& b) const {
+    std::hash<const HloInstruction*> hasher;
+    size_t h = hasher(b.instruction_);
+    for (int i = 0; i < b.index_.size(); i++) {
+      h += static_cast<size_t>(b.index_[i] << i);
+    }
+    return h;
+  }
+};
+
+std::ostream& operator<<(std::ostream& out, const LogicalBuffer& buffer);
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LOGICAL_BUFFER_H_
diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc
new file mode 100644
index 0000000000..4014856b9b
--- /dev/null
+++ b/tensorflow/compiler/xla/service/name_uniquer.cc
@@ -0,0 +1,37 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/name_uniquer.h"
+
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+string NameUniquer::GetUniqueName(tensorflow::StringPiece prefix) {
+  string root = prefix.empty() ?
"name" : prefix.ToString(); + int* count = &(generated_names_[root]); + if (*count == 0) { + *count = 1; + return root; + } else { + tensorflow::strings::StrAppend(&root, separator_, *count); + (*count)++; + return root; + } +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/name_uniquer.h b/tensorflow/compiler/xla/service/name_uniquer.h new file mode 100644 index 0000000000..b0944adbc1 --- /dev/null +++ b/tensorflow/compiler/xla/service/name_uniquer.h @@ -0,0 +1,53 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_NAME_UNIQUER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_NAME_UNIQUER_H_ + +#include +#include + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/macros.h" + +namespace xla { + +// Simple stateful class that helps generate "unique" names. To use it, simply +// call GetUniqueName as many times as needed. The names returned by +// GetUniqueName are guaranteed to be distinct for this instance of the class. +class NameUniquer { + public: + explicit NameUniquer(const string& separator = "__") + : separator_(separator) {} + + // Get a unique name in a string, with an optional prefix for convenience. + string GetUniqueName(tensorflow::StringPiece prefix = ""); + + private: + // The string to use to separate the prefix of the name from the uniquing + // integer value. + string separator_; + + // Map from name prefix to the number of names generated using that prefix + // so far. + std::unordered_map generated_names_; + + TF_DISALLOW_COPY_AND_ASSIGN(NameUniquer); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_NAME_UNIQUER_H_ diff --git a/tensorflow/compiler/xla/service/platform_util.cc b/tensorflow/compiler/xla/service/platform_util.cc new file mode 100644 index 0000000000..116bd3f067 --- /dev/null +++ b/tensorflow/compiler/xla/service/platform_util.cc @@ -0,0 +1,166 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/platform_util.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/service/compiler.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace se = ::perftools::gputools; + +namespace xla { + +// Minimum supported CUDA compute capability is 3.5. +constexpr int kMinCudaComputeCapabilityMajor = 3; +constexpr int kMinCudaComputeCapabilityMinor = 5; + +/* static */ StatusOr> +PlatformUtil::GetSupportedPlatforms() { + se::MultiPlatformManager::PlatformMap platform_map; + se::port::Status platforms_status = se::MultiPlatformManager::WithPlatforms( + [&platform_map](se::MultiPlatformManager::PlatformMap* map) { + platform_map = *map; + return se::port::Status::OK(); + }); + if (platform_map.empty()) { + LOG(WARNING) << "no executor platforms available: platform map is empty"; + } + + // Gather all platforms which have an XLA compiler. + std::vector platforms; + for (auto& platform_pair : platform_map) { + auto* platform = platform_pair.second; + auto compiler_status = Compiler::GetForPlatform(platform); + if (compiler_status.ok()) { + if (platform->VisibleDeviceCount() > 0) { + LOG(INFO) << "platform " << platform->Name() << " present with " + << platform->VisibleDeviceCount() << " visible devices"; + } else { + LOG(WARNING) << "platform " << platform->Name() << " present but no " + << "visible devices found"; + } + // Note: currently we call zero device platforms "supported" on the basis + // that, if the platform support was linked in, it was probably intended + // to be used for execution, and this way we can flag an error. + // + // TODO(b/33730287) If we want an alternative version of this behavior we + // could add an --xla_fallback_to_host flag. + platforms.push_back(platform); + } else { + LOG(INFO) << "platform " << platform->Name() << " present but no " + << "XLA compiler available: " + << compiler_status.status().error_message(); + } + } + return platforms; +} + +/* static */ StatusOr PlatformUtil::GetDefaultPlatform() { + TF_ASSIGN_OR_RETURN(auto platforms, GetSupportedPlatforms()); + if (platforms.empty()) { + return NotFound("no platforms found"); + } else if (platforms.size() == 1) { + return platforms[0]; + } else if (platforms.size() == 2) { + // In the service we always link the cpu backend for ComputeConstant. So if + // one of the two platforms is CPU then pick the other (non-cpu) platform as + // the default. + if (platforms[0]->id() == se::host::kHostPlatformId) { + return platforms[1]; + } else if (platforms[1]->id() == se::host::kHostPlatformId) { + return platforms[0]; + } + } + + // Multiple platforms present and we can't pick a reasonable default. + auto l = [](string* out, const se::Platform* p) { out->append(p->Name()); }; + string platforms_string = tensorflow::str_util::Join(platforms, ", ", l); + return InvalidArgument( + "must specify platform because more than one platform found: %s", + platforms_string.c_str()); +} + +// Returns whether the device underlying the given StreamExecutor is supported +// by XLA. 
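+// (As written, this only filters out CUDA devices below compute capability
+// 3.5, per kMinCudaComputeCapability{Major,Minor} above; devices on other
+// platforms are accepted unconditionally.)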
+static bool IsDeviceSupported(se::StreamExecutor* executor) { + const auto& description = executor->GetDeviceDescription(); + if (executor->platform()->id() == se::cuda::kCudaPlatformId) { + // CUDA devices must have a minimum compute capability. + int major_version, minor_version; + if (description.cuda_compute_capability(&major_version, &minor_version)) { + if (major_version < kMinCudaComputeCapabilityMajor || + (major_version == kMinCudaComputeCapabilityMajor && + minor_version < kMinCudaComputeCapabilityMinor)) { + LOG(INFO) << "StreamExecutor cuda device (" + << executor->device_ordinal() << ") is of " + << "insufficient compute capability: " + << kMinCudaComputeCapabilityMajor << "." + << kMinCudaComputeCapabilityMinor << " required, " + << "device is " << major_version << "." << minor_version; + return false; + } + } + } + return true; +} + +/* static */ StatusOr> +PlatformUtil::GetStreamExecutors(se::Platform* platform) { + int device_count = platform->VisibleDeviceCount(); + if (device_count <= 0) { + return NotFound("no %s devices found", platform->Name().c_str()); + } + if (platform->id() == se::host::kHostPlatformId) { + // On host "devices", StreamExecutor exports a device for each hardware + // thread. Because we parallelize a single computation across threads, it + // doesn't make sense to expose these as separate devices, so fix the number + // of devices to one. + device_count = 1; + } + std::vector stream_executors(device_count, nullptr); + for (int i = 0; i < device_count; ++i) { + se::StreamExecutorConfig config; + config.ordinal = i; + auto executor_status = platform->GetExecutor(config); + if (executor_status.ok()) { + se::StreamExecutor* executor = executor_status.ValueOrDie(); + if (IsDeviceSupported(executor)) { + stream_executors[i] = executor; + } + } else { + LOG(WARNING) << "unable to create StreamExecutor for " << platform->Name() + << ":" << i << ": " + << executor_status.status().error_message(); + } + } + if (std::all_of(stream_executors.begin(), stream_executors.end(), + [](se::StreamExecutor* s) { return s == nullptr; })) { + return InternalError("no supported devices found for platform %s", + platform->Name().c_str()); + } + return stream_executors; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/platform_util.h b/tensorflow/compiler/xla/service/platform_util.h new file mode 100644 index 0000000000..fe0281a69a --- /dev/null +++ b/tensorflow/compiler/xla/service/platform_util.h @@ -0,0 +1,61 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_PLATFORM_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_PLATFORM_UTIL_H_ + +#include + +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { + +// Utilities for querying platforms and devices used by XLA. 
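+//
+// Illustrative use (not part of this change; names as declared below):
+//   TF_ASSIGN_OR_RETURN(auto* platform, PlatformUtil::GetDefaultPlatform());
+//   TF_ASSIGN_OR_RETURN(auto executors,
+//                       PlatformUtil::GetStreamExecutors(platform));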
+class PlatformUtil {
+ public:
+  // Returns the platforms present on the system and supported by XLA.
+  //
+  // Note that, even if a platform is present with zero devices, if we *do* have
+  // compilation support for it, it will be returned in this sequence.
+  static StatusOr<std::vector<perftools::gputools::Platform*>>
+  GetSupportedPlatforms();
+
+  // Convenience function which returns the default supported platform. If
+  // exactly one supported platform is present, then this platform is the
+  // default platform. If exactly two supported platforms are present and one
+  // platform is CPU (host) then the non-CPU platform is default. This logic is
+  // used because the XLA service always links in the CPU backend to run
+  // ComputeConstant, so if exactly one other platform is linked in, we assume
+  // the intent is to execute on that non-CPU platform. If none of these
+  // conditions are met, the function returns an error.
+  static StatusOr<perftools::gputools::Platform*> GetDefaultPlatform();
+
+  // Returns a vector of StreamExecutors for the given platform. The vector is
+  // indexed by device ordinal (device numbering used by StreamExecutor). If an
+  // element is nullptr, then the device is present but not supported by XLA.
+  //
+  // If the platform has no visible devices, a not-found error is returned.
+  static StatusOr<std::vector<perftools::gputools::StreamExecutor*>>
+  GetStreamExecutors(perftools::gputools::Platform* platform);
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(PlatformUtil);
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_PLATFORM_UTIL_H_
diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc
new file mode 100644
index 0000000000..5625804c2e
--- /dev/null
+++ b/tensorflow/compiler/xla/service/reshape_mover.cc
@@ -0,0 +1,120 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/reshape_mover.h"
+
+#include <algorithm>
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/util.h"
+
+namespace xla {
+
+namespace {
+
+// Returns whether `a` and `b` are equivalent for the purposes of this pass.
+bool AreEquivalentReshapes(const HloInstruction* a, const HloInstruction* b) {
+  if (a->opcode() != b->opcode() ||
+      !ShapeUtil::SameDimensions(a->shape(), b->shape())) {
+    return false;
+  }
+  switch (a->opcode()) {
+    case HloOpcode::kTranspose:
+      return a->dimensions() == b->dimensions();
+    case HloOpcode::kReshape:
+      return ShapeUtil::SameDimensions(a->operand(0)->shape(),
+                                       b->operand(0)->shape());
+    default:
+      return false;
+  }
+}
+
+bool IsElementwiseOfEquivalentReshapesOrTransposes(
+    const HloInstruction* instruction) {
+  const std::vector<HloInstruction*>& operands = instruction->operands();
+  return instruction->IsElementwise() && instruction->operand_count() > 0 &&
+         std::all_of(operands.begin(), operands.end(),
+                     [](const HloInstruction* instruction) {
+                       // We require that the operand have no other users, as
+                       // otherwise this is not a clear win.
+ return 1 == instruction->users().size(); + }) && + // Check whether each operand beyond the first is equivalent to the + // first. + std::all_of(operands.begin(), operands.end(), + [&operands](const HloInstruction* operand) { + return AreEquivalentReshapes(operands[0], operand); + }); +} + +// Try to sink any reshape or transpose operands of `instruction` across it. We +// do so if `instruction` is elementwise and all operands are equivalent +// reshapes or transposes. +bool TrySinkReshapeOrTranspose(HloComputation* computation, + HloInstruction* instruction) { + if (IsElementwiseOfEquivalentReshapesOrTransposes(instruction)) { + std::vector operands = instruction->operands(); + auto old_reshape = operands[0]; + for (size_t i = 0; i < operands.size(); ++i) { + operands[i] = operands[i]->mutable_operand(0); + } + auto new_elementwise = + computation->AddInstruction(instruction->CloneWithNewOperands( + // `instruction` may change the element type, e.g., from + // operands[0] -> reshape -> convert (`instruction`) + // to + // operands[0] -> convert' -> reshape' + // + // In this case, convert' should have the same element type as + // `convert` and the same dimensions as operands[0]. + ShapeUtil::MakeShape( + instruction->shape().element_type(), + AsInt64Slice(operands[0]->shape().dimensions())), + operands)); + std::unique_ptr new_reshape; + switch (old_reshape->opcode()) { + case HloOpcode::kReshape: + new_reshape = HloInstruction::CreateReshape(instruction->shape(), + new_elementwise); + break; + case HloOpcode::kTranspose: + new_reshape = HloInstruction::CreateTranspose( + instruction->shape(), new_elementwise, old_reshape->dimensions()); + break; + default: + LOG(FATAL) << "Bad opcode"; + } + computation->ReplaceWithNewInstruction(instruction, std::move(new_reshape)); + return true; + } + return false; +} + +} // namespace + +StatusOr ReshapeMover::Run(HloModule* module) { + return std::any_of( + module->computations().begin(), module->computations().end(), + [](const std::unique_ptr& computation) { + std::list postorder = + computation->MakeInstructionPostOrder(); + return std::any_of(postorder.begin(), postorder.end(), + [&computation](HloInstruction* instruction) { + return TrySinkReshapeOrTranspose(computation.get(), + instruction); + }); + }); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/reshape_mover.h b/tensorflow/compiler/xla/service/reshape_mover.h new file mode 100644 index 0000000000..f7146b0ee3 --- /dev/null +++ b/tensorflow/compiler/xla/service/reshape_mover.h @@ -0,0 +1,36 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_RESHAPE_MOVER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_RESHAPE_MOVER_H_ + +#include "tensorflow/compiler/xla/service/hlo_pass.h" + +namespace xla { + +// A pass which moves Reshapes and Transposes to let later passes combine them. 
+// This now only moves them outputward across elementwise ops all whose operands +// are equivalent Reshapes or Transposes, but in future could potentially move +// them inputward also. +class ReshapeMover : public HloPass { + public: + ReshapeMover() : HloPass("reshape motion") {} + + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_RESHAPE_MOVER_H_ diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc new file mode 100644 index 0000000000..850295c726 --- /dev/null +++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc @@ -0,0 +1,57 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/reshape_mover.h" + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/strings/str_util.h" + +namespace xla { +namespace { +using ReshapeMoverTest = HloTestBase; + +TEST_F(ReshapeMoverTest, ReshapesWithNonSameInputShapesNotMoved) { + auto root_shape = ShapeUtil::MakeShape(F32, {8, 7}); + HloComputation::Builder builder(TestName()); + auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1, 8, 1, 7}), "param0")); + auto param1 = builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {1, 8, 7, 1}), "param0")); + auto reshape2 = + builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0)); + auto reshape3 = + builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param1)); + auto add4 = builder.AddInstruction(HloInstruction::CreateBinary( + root_shape, HloOpcode::kAdd, reshape2, reshape3)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + EXPECT_EQ(add4, computation->root_instruction()); + EXPECT_FALSE(ReshapeMover().Run(module.get()).ValueOrDie()); + EXPECT_EQ(add4, computation->root_instruction()); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc new file mode 100644 index 0000000000..847aea7888 --- /dev/null +++ b/tensorflow/compiler/xla/service/service.cc @@ -0,0 +1,1428 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/service.h" + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/legacy_flags/service_flags.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/compiler.h" +#include "tensorflow/compiler/xla/service/computation_layout.h" +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/executable.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" +#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/platform_util.h" +#include "tensorflow/compiler/xla/service/session.pb.h" +#include "tensorflow/compiler/xla/service/transfer_manager.h" +#include "tensorflow/compiler/xla/shape_layout.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/regexp.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/types.h" + +namespace se = ::perftools::gputools; + +namespace xla { + +namespace { + +// Copies the contents of an Allocation into a Literal proto. +tensorflow::Status LiteralFromAllocation(const Allocation* allocation, + const Shape& literal_shape, + Literal* literal) { + TF_ASSIGN_OR_RETURN( + se::StreamExecutor * executor, + allocation->backend()->stream_executor(allocation->device_ordinal())); + return allocation->backend()->transfer_manager()->TransferLiteralFromDevice( + executor, allocation->device_memory(), allocation->shape(), literal_shape, + literal); +} + +// Records the arguments used to invoke a computation in a SessionModule +// proto. +tensorflow::Status RecordArguments( + const tensorflow::gtl::ArraySlice arg_allocations, + SessionModule* module) { + module->clear_arguments(); + for (const Allocation* allocation : arg_allocations) { + TF_RETURN_IF_ERROR(LiteralFromAllocation(allocation, allocation->shape(), + module->add_arguments())); + } + return tensorflow::Status::OK(); +} + +// Records the result of a computation in a SessionModule proto. 
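+// (RecordArguments and RecordResult appear to exist to populate the
+// SessionModule snapshots written out by the xla_dump_* flags used below.)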
+tensorflow::Status RecordResult(const Allocation* result_allocation, + SessionModule* module) { + module->clear_result(); + return LiteralFromAllocation(result_allocation, result_allocation->shape(), + module->mutable_result()); +} + +} // namespace + +ServiceOptions& ServiceOptions::set_platform( + perftools::gputools::Platform* platform) { + platform_ = platform; + return *this; +} + +perftools::gputools::Platform* ServiceOptions::platform() const { + return platform_; +} + +ServiceOptions& ServiceOptions::set_number_of_replicas(int number_of_replicas) { + number_of_replicas_ = number_of_replicas; + return *this; +} + +int ServiceOptions::number_of_replicas() const { return number_of_replicas_; } + +/* static */ StatusOr> Service::NewService( + perftools::gputools::Platform* platform) { + ServiceOptions default_options; + default_options.set_platform(platform); + return NewService(default_options); +} + +/* static */ StatusOr> Service::NewService( + const ServiceOptions& options) { + perftools::gputools::Platform* platform = options.platform(); + std::unique_ptr execute_backend; + if (platform == nullptr) { + TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform()); + } + TF_ASSIGN_OR_RETURN( + execute_backend, + Backend::CreateBackend(platform, options.number_of_replicas())); + TF_ASSIGN_OR_RETURN(std::unique_ptr compute_constant_backend, + CreateComputeConstantBackend()); + std::unique_ptr service(new Service( + std::move(execute_backend), std::move(compute_constant_backend))); + return std::move(service); +} + +/* static */ StatusOr> +Service::CreateComputeConstantBackend() { + TF_ASSIGN_OR_RETURN(std::vector platforms, + PlatformUtil::GetSupportedPlatforms()); + for (auto* platform : platforms) { + if (platform->id() == se::host::kHostPlatformId) { + return Backend::CreateBackend(platform, /*replica_count=*/1); + } + } + return NotFound("CPU platform not found"); +} + +/* static */ void Service::DumpExecutedHlo(const HloModule& module, + const string& label, + const HloExecutionProfile* profile) { + VLOG(2) << "module name = " << module.name(); + legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags(); + if (!flags->xla_generate_hlo_graph.empty() && + RE2::PartialMatch(module.name(), flags->xla_generate_hlo_graph)) { + hlo_graph_dumper::DumpGraph(*module.entry_computation(), label, + flags->xla_hlo_graph_addresses, + flags->xla_hlo_graph_layout, profile); + } + if (!flags->xla_log_hlo_text.empty() && + RE2::PartialMatch(module.name(), flags->xla_log_hlo_text)) { + LOG(INFO) << "HLO for module " << module.name(); + LOG(INFO) << "Label: " << label; + XLA_LOG_LINES(2, module.ToString()); + } + if (!flags->xla_dump_hlo_text_to.empty()) { + hlo_graph_dumper::DumpText(module, label, flags->xla_dump_hlo_text_to); + } +} + +/* static */ Compiler::HloDumper Service::MakeHloDumper() { + return [](const HloModule& module, const string& label) { + return DumpExecutedHlo(module, label, /*profile=*/nullptr); + }; +} + +Service::Service(std::unique_ptr execute_backend, + std::unique_ptr compute_constant_backend) + : execute_backend_(std::move(execute_backend)), + compute_constant_backend_(std::move(compute_constant_backend)) { + LOG(INFO) << "XLA service executing computations on platform " + << execute_backend_->platform()->Name() << ". 
Devices:"; + for (int i = 0; i < execute_backend_->device_count(); ++i) { + if (execute_backend_->device_ordinal_supported(i)) { + se::StreamExecutor* executor = + execute_backend_->stream_executor(i).ValueOrDie(); + const auto& description = executor->GetDeviceDescription(); + LOG(INFO) << tensorflow::strings::Printf( + " StreamExecutor device (%d): %s, %s", i, description.name().c_str(), + description.platform_version().c_str()); + } else { + LOG(INFO) << tensorflow::strings::Printf( + " StreamExecutor device (%d) not supported", i); + } + } +} + +tensorflow::Status Service::Computation(const ComputationRequest* arg, + ComputationResponse* result) { + if (arg->name().empty()) { + return InvalidArgument("computation request needs a name"); + } + + *result->mutable_computation() = + computation_tracker_.NewComputation(arg->name()); + return tensorflow::Status::OK(); +} + +tensorflow::Status Service::CreateChannelHandle( + const CreateChannelHandleRequest* arg, + CreateChannelHandleResponse* result) { + *result->mutable_channel() = channel_tracker_.NewChannel(); + return tensorflow::Status::OK(); +} + +tensorflow::Status Service::Unregister(const UnregisterRequest* arg, + UnregisterResponse* result) { + return allocation_tracker_.Unregister(arg->data()); +} + +// Deconstructs a previously-allocated global handle. +tensorflow::Status Service::DeconstructTuple(const DeconstructTupleRequest* arg, + DeconstructTupleResponse* result) { + TF_ASSIGN_OR_RETURN( + std::vector elements, + allocation_tracker_.DeconstructTuple(arg->tuple_handle())); + + for (auto& element : elements) { + *result->add_element_handles() = element; + } + return tensorflow::Status::OK(); +} + +tensorflow::Status Service::ValidateResultShapeWithLayout( + const Shape& shape_with_layout, const Shape& result_shape) const { + if (!ShapeUtil::Compatible(shape_with_layout, result_shape)) { + return InvalidArgument( + "Shape used to set computation result layout %s is not compatible " + "with result shape %s", + ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str(), + ShapeUtil::HumanString(result_shape).c_str()); + } + if (!LayoutUtil::HasLayout(shape_with_layout)) { + return InvalidArgument( + "Shape used to set computation result layout %s does not have layout", + ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str()); + } + return ShapeUtil::ValidateShape(shape_with_layout); +} + +StatusOr> Service::ResolveAndValidateArguments( + tensorflow::gtl::ArraySlice arguments, + const Backend* backend, int device_ordinal) { + std::vector allocations; + for (int i = 0; i < arguments.size(); ++i) { + auto allocation_status = allocation_tracker_.Resolve(*arguments[i]); + if (!allocation_status.ok()) { + return Status(allocation_status.status().code(), + tensorflow::strings::StrCat( + allocation_status.status().error_message(), ", ", + "failed to resolve allocation for parameter ", i)); + } + const Allocation* allocation = allocation_status.ValueOrDie(); + + // Verify allocation is same platform and device as the execution. 
+ if (allocation->backend() != backend || + allocation->device_ordinal() != device_ordinal) { + return InvalidArgument( + "argument %d is on device %s but computation will be executed " + "on device %s", + i, + allocation->backend() + ->device_name(allocation->device_ordinal()) + .c_str(), + backend->device_name(device_ordinal).c_str()); + } + + allocations.push_back(allocation); + } + return allocations; +} + +StatusOr> Service::CreateModuleConfig( + const ProgramShape& program_shape, + tensorflow::gtl::ArraySlice arguments, + const Shape* shape_with_output_layout, uint64 seed) { + auto module_config = MakeUnique(program_shape); + auto* computation_layout = module_config->mutable_entry_computation_layout(); + + if (program_shape.parameters_size() != arguments.size()) { + return InvalidArgument("computation takes %d parameters, but %zu given", + program_shape.parameters_size(), arguments.size()); + } + + for (int i = 0; i < arguments.size(); ++i) { + // Verify that shape of arguments matches the shape of the arguments in the + // ProgramShape. + if (!ShapeUtil::Compatible(arguments[i]->shape(), + program_shape.parameters(i))) { + return InvalidArgument( + "computation expects parameter %d to have shape %s, given shape %s", + i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(), + ShapeUtil::HumanString(arguments[i]->shape()).c_str()); + } + TF_RETURN_IF_ERROR( + computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape( + arguments[i]->shape())); + } + if (shape_with_output_layout == nullptr) { + computation_layout->mutable_result_layout()->Clear(); + } else { + TF_RETURN_IF_ERROR(ValidateResultShapeWithLayout(*shape_with_output_layout, + program_shape.result())); + TF_RETURN_IF_ERROR( + computation_layout->mutable_result_layout()->CopyLayoutFromShape( + *shape_with_output_layout)); + } + + legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags(); + if (flags->xla_hlo_profile) { + module_config->enable_hlo_profiling(true); + } + + module_config->set_seed(seed); + module_config->set_replica_count(execute_backend_->Replicas().size()); + + return std::move(module_config); +} + +StatusOr>> Service::BuildExecutables( + std::vector versioned_handles, + std::vector> module_configs, + Backend* backend, + std::vector executors) { + // Dump computation proto state if flag is set. 
+ std::vector> session_modules; + legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags(); + const string& directory_path = flags->xla_dump_computations_to; + const string& other_directory_path = flags->xla_dump_executions_to; + if ((!directory_path.empty() || !other_directory_path.empty())) { + for (int64 i = 0; i < versioned_handles.size(); ++i) { + TF_ASSIGN_OR_RETURN(std::unique_ptr session_module, + computation_tracker_.SnapshotComputation( + versioned_handles[i].handle)); + if (!directory_path.empty()) { + string filename = + tensorflow::strings::Printf("computation_%lld__%s__version_%lld", + versioned_handles[i].handle.handle(), + session_module->entry().name().c_str(), + versioned_handles[i].version); + TF_RETURN_IF_ERROR(Executable::DumpToDirectory(directory_path, filename, + *session_module)); + session_modules.push_back(std::move(session_module)); + } + } + } + + VLOG(1) << "building executables from:"; + for (const VersionedComputationHandle& versioned_handle : versioned_handles) { + VLOG(1) << versioned_handle.handle.handle() << "@v" + << versioned_handle.version; + } + + std::vector> modules; + for (const VersionedComputationHandle& versioned_handle : versioned_handles) { + TF_ASSIGN_OR_RETURN(auto module, + computation_tracker_.BuildHloModule( + versioned_handle, + /*include_unused_parameters=*/true)); + modules.push_back(std::move(module)); + } + + Compiler::HloDumper hlo_dumper = MakeHloDumper(); + TF_ASSIGN_OR_RETURN(std::vector> executables, + backend->compiler()->Compile( + std::move(modules), std::move(module_configs), + hlo_dumper, std::move(executors))); + + if (!other_directory_path.empty()) { + for (int64 i = 0; i < versioned_handles.size(); ++i) { + executables[i]->set_session_module(std::move(session_modules[i])); + } + } + + return std::move(executables); +} + +StatusOr> Service::BuildExecutable( + const VersionedComputationHandle& versioned_handle, + std::unique_ptr module_config, + bool executable_for_compute_constant, + const tensorflow::gtl::ArraySlice + arguments, + Backend* backend, se::StreamExecutor* executor) { + // Dump computation proto state if flag is set. 
+ std::unique_ptr session_module; + legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags(); + const string& directory_path = flags->xla_dump_computations_to; + const string& other_directory_path = flags->xla_dump_executions_to; + if (!executable_for_compute_constant && + (!directory_path.empty() || !other_directory_path.empty())) { + TF_ASSIGN_OR_RETURN( + session_module, + computation_tracker_.SnapshotComputation(versioned_handle.handle)); + if (!directory_path.empty()) { + string filename = tensorflow::strings::Printf( + "computation_%lld__%s__version_%lld", + versioned_handle.handle.handle(), + session_module->entry().name().c_str(), versioned_handle.version); + TF_RETURN_IF_ERROR(Executable::DumpToDirectory(directory_path, filename, + *session_module)); + } + } + + VLOG(1) << tensorflow::strings::Printf("building executable %lld@v%lld", + versioned_handle.handle.handle(), + versioned_handle.version); + + TF_ASSIGN_OR_RETURN( + std::unique_ptr module, + computation_tracker_.BuildHloModule( + versioned_handle, + /*include_unused_parameters=*/!executable_for_compute_constant)); + + Compiler::HloDumper hlo_dumper = MakeHloDumper(); + if (executable_for_compute_constant && + !flags->xla_hlo_graph_for_compute_constant) { + hlo_dumper = [](const HloModule&, const string&) {}; + } + + TF_ASSIGN_OR_RETURN( + std::unique_ptr executable, + backend->compiler()->Compile(std::move(module), std::move(module_config), + hlo_dumper, executor)); + + if (!other_directory_path.empty()) { + executable->set_session_module(std::move(session_module)); + } + + return std::move(executable); +} + +StatusOr> Service::BuildAndCacheExecutable( + const VersionedComputationHandle& versioned_handle, + std::unique_ptr module_config, + const tensorflow::gtl::ArraySlice + arguments, + Backend* backend, perftools::gputools::StreamExecutor* executor, + ExecutionProfile* profile) { + std::shared_ptr executable = + compilation_cache_.LookUp(versioned_handle, *module_config); + + if (executable != nullptr) { + // Executable found in the computation cache. + if (profile != nullptr) { + profile->set_compilation_cache_hit(true); + } + return executable; + } + + uint64 start_micros = + // Avoid reading the clock if we don't want timing info + (profile != nullptr) ? tensorflow::Env::Default()->NowMicros() : 0; + + // Take a copy of the module config, as compilation introduces layouts where + // layouts were optional before. + HloModuleConfig original_module_config = *module_config; + TF_ASSIGN_OR_RETURN( + std::unique_ptr executable_unique_ptr, + BuildExecutable(versioned_handle, std::move(module_config), + /*executable_for_compute_constant=*/false, arguments, + execute_backend_.get(), executor)); + + if (profile != nullptr) { + uint64 end_micros = tensorflow::Env::Default()->NowMicros(); + uint64 milliseconds = (end_micros - start_micros) / 1000; + profile->set_compilation_cache_hit(false); + profile->set_compile_time_ms(milliseconds); + } + + // Insert executable into the cache. + return compilation_cache_.Insert(std::move(executable_unique_ptr), + original_module_config); +} + +StatusOr> +Service::ExecuteParallelAndRegisterResult( + tensorflow::gtl::ArraySlice executables, + tensorflow::gtl::ArraySlice< + std::vector> + arguments, + Backend* backend, + tensorflow::gtl::ArraySlice executors, + tensorflow::gtl::ArraySlice result_tags) { + // TODO(b/33943292): Support for replication when using multiple computations. + TF_RET_CHECK(backend->Replicas().size() == 1); + + // Set up streams. 
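+  // (One stream is acquired from the backend per executor below; the cleanup
+  // object releases every acquired stream, including on early error returns.)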
+ std::vector> streams; + + auto stream_releaser = ::tensorflow::gtl::MakeCleanup([backend, &streams]() { + for (std::unique_ptr& stream : streams) { + backend->ReleaseStream(std::move(stream)); + } + }); + + for (se::StreamExecutor* executor : executors) { + TF_ASSIGN_OR_RETURN(std::unique_ptr stream, + backend->AcquireStream(executor)); + // Push back after so that the releaser only sees real streams. + streams.push_back(std::move(stream)); + } + + // Set up run options. + std::vector run_options; + for (const std::unique_ptr& stream : streams) { + run_options.emplace_back(); + auto& options = run_options.back(); + options.set_stream(stream.get()); + options.set_allocator(backend->memory_allocator()); + options.set_inter_op_thread_pool(backend->inter_op_thread_pool()); + options.set_intra_op_thread_pool( + backend->eigen_intra_op_thread_pool_device()); + } + + // Asynchronously launch all executables. + std::vector result_handles; + for (int64 i = 0; i < executables.size(); i++) { + TF_ASSIGN_OR_RETURN( + perftools::gputools::DeviceMemoryBase result, + executables[i]->ExecuteAsyncOnStream(&run_options[i], arguments[i])); + result_handles.push_back(allocation_tracker_.Register( + backend, executors[i]->device_ordinal(), result, + executables[i]->result_shape(), result_tags[i])); + } + + // Wait for all executions to complete. + for (int64 i = 0; i < result_handles.size(); ++i) { + if (!streams[i]->BlockHostUntilDone()) { + return InternalError("failed to complete execution for stream %lld", i); + } + } + + return result_handles; +} + +StatusOr Service::ExecuteAndRegisterResult( + Executable* executable, + const tensorflow::gtl::ArraySlice + arguments, + Backend* backend, perftools::gputools::StreamExecutor* executor, + const string& result_tag, ExecutionProfile* profile) { + TF_RET_CHECK(!backend->Replicas().empty()); + + // Set up streams. + std::vector> streams; + + auto stream_releaser = ::tensorflow::gtl::MakeCleanup([backend, &streams]() { + for (std::unique_ptr& stream : streams) { + backend->ReleaseStream(std::move(stream)); + } + }); + + for (se::StreamExecutor* executor : backend->Replicas()) { + TF_ASSIGN_OR_RETURN(std::unique_ptr stream, + backend->AcquireStream(executor)); + // Push back after so that the releaser only sees real streams. + streams.push_back(std::move(stream)); + } + + // Set up run options. 
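Both functions above use the same stream discipline: the release callback is installed before any acquisition, and a stream is pushed only after AcquireStream succeeds, so every exit path returns exactly the streams that were really acquired. A self-contained sketch of that scope-exit idiom, with a hand-rolled guard standing in for gtl::MakeCleanup and plain ints standing in for streams:

#include <functional>
#include <iostream>
#include <utility>
#include <vector>

// Minimal scope-exit guard, analogous in spirit to gtl::MakeCleanup.
class ScopeExit {
 public:
  explicit ScopeExit(std::function<void()> fn) : fn_(std::move(fn)) {}
  ~ScopeExit() { fn_(); }

 private:
  std::function<void()> fn_;
};

int main() {
  std::vector<int> acquired;
  {
    // Installed before the acquisition loop; runs on every exit path.
    ScopeExit release([&acquired] { acquired.clear(); });
    for (int i = 0; i < 3; ++i) {
      acquired.push_back(i);  // push only after a successful "acquisition"
      if (i == 1) break;      // an early exit still triggers the guard
    }
  }
  std::cout << "held after release: " << acquired.size() << "\n";  // prints 0
  return 0;
}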
+ std::vector run_options; + for (const std::unique_ptr& stream : streams) { + run_options.emplace_back(); + auto& options = run_options.back(); + options.set_stream(stream.get()); + options.set_allocator(backend->memory_allocator()); + options.set_inter_op_thread_pool(backend->inter_op_thread_pool()); + options.set_intra_op_thread_pool( + backend->eigen_intra_op_thread_pool_device()); + } + + perftools::gputools::DeviceMemoryBase result; + if (backend->Replicas().size() == 1) { + TF_ASSIGN_OR_RETURN( + result, ExecuteOnStreamWrapper>( + executable, &run_options[0], profile, + [&arguments](Executable* executable, + const ExecutableRunOptions* run_options, + HloExecutionProfile* hlo_execution_profile) { + return executable->ExecuteOnStream(run_options, arguments, + hlo_execution_profile); + })); + } else { + std::vector< + tensorflow::gtl::ArraySlice> + repeated_arguments(backend->Replicas().size(), arguments); + + TF_ASSIGN_OR_RETURN( + auto results, + executable->ExecuteOnStreams(run_options, repeated_arguments)); + TF_RET_CHECK(!results.empty()); + result = results[0]; + } + return allocation_tracker_.Register(backend, executor->device_ordinal(), + result, executable->result_shape(), + result_tag); +} + +tensorflow::Status Service::SetReturnValue(const SetReturnValueRequest* arg, + SetReturnValueResponse* results) { + TF_ASSIGN_OR_RETURN(UserComputation * computation, + computation_tracker_.Resolve(arg->computation())); + return computation->SetReturnValue(arg->operand()); +} + +tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, + ExecuteParallelResponse* result) { + VLOG(1) << "running execute-parallel request: " << arg->ShortDebugString(); + + std::vector> all_arguments; + std::vector executors; + std::vector versioned_handles; + std::vector> module_configs; + std::vector computation_names; + + if (arg->requests_size() > execute_backend_->stream_executors().size()) { + return FailedPrecondition( + "there are not enough stream executors to execute %d computations", + arg->requests_size()); + } + + for (int64 i = 0; i < arg->requests_size(); ++i) { + // Get the stream executor on which the computation will run. Select the + // specific device if requested, otherwise select the i'th device from the + // list of available stream executors. + se::StreamExecutor* executor; + if (arg->requests(i).has_device_handle()) { + executor = + execute_backend_ + ->stream_executors()[arg->requests(i).device_handle().handle()]; + } else { + executor = execute_backend_->stream_executors()[i]; + } + CHECK(executor != nullptr); + + // Resolve the UserComputation object associated with the requested + // computation and compute the program shape. + const ExecuteRequest& request = arg->requests(i); + TF_ASSIGN_OR_RETURN(UserComputation * user_computation, + computation_tracker_.Resolve(request.computation())); + VersionedComputationHandle versioned_handle = + user_computation->GetVersionedHandle(); + if (user_computation->request_count(versioned_handle.version) == 0) { + return InvalidArgument("computations may not be empty"); + } + + TF_ASSIGN_OR_RETURN( + std::shared_ptr program_shape, + user_computation->ComputeProgramShape(versioned_handle.version)); + const Shape* shape_with_output_layout = + request.has_shape_with_output_layout() + ? &request.shape_with_output_layout() + : nullptr; + + // Resolve the allocations for the arguments of the computation, and create + // a vector of device memory offsets for the arguments from the allocations. 
+ TF_ASSIGN_OR_RETURN( + std::vector arg_allocations, + ResolveAndValidateArguments(request.arguments(), execute_backend_.get(), + executor->device_ordinal())); + std::vector arguments; + for (const Allocation* allocation : arg_allocations) { + arguments.push_back(allocation->device_memory()); + } + + // Create an HloModuleConfig object for the computation, given the shape of + // the program and the argument allocations. + TF_ASSIGN_OR_RETURN( + std::unique_ptr module_config, + CreateModuleConfig(*program_shape, arg_allocations, + shape_with_output_layout, request.seed())); + VLOG(3) << "ExecuteParallel created HloModuleConfig computation layout: " + << module_config->entry_computation_layout().ToString(); + + // Adds to the vectors to build and execute the computations after the loop. + all_arguments.push_back(arguments); + versioned_handles.push_back(versioned_handle); + module_configs.push_back(std::move(module_config)); + computation_names.push_back(user_computation->name()); + executors.push_back(executor); + } + + // Build the user computations into HloModules and compile to generate the + // executables. + TF_ASSIGN_OR_RETURN( + std::vector> executables, + BuildExecutables(versioned_handles, std::move(module_configs), + execute_backend_.get(), executors)); + std::vector executable_ptrs; + for (const auto& executable : executables) { + executable_ptrs.push_back(executable.get()); + } + + // Execute the generated executables in parallel and return the device + // handles for each computation's output. + TF_ASSIGN_OR_RETURN( + std::vector outputs, + ExecuteParallelAndRegisterResult(executable_ptrs, all_arguments, + execute_backend_.get(), executors, + computation_names)); + for (const GlobalDataHandle& output : outputs) { + ExecuteResponse response; + *response.mutable_output() = output; + *result->add_responses() = response; + } + + VLOG(1) << "successfully completed 'execute-parallel' request"; + return tensorflow::Status::OK(); +} + +tensorflow::Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, + GetDeviceHandlesResponse* result) { + const int64 available_device_count = + execute_backend_->stream_executors().size(); + const int64 replicas = execute_backend_->Replicas().size(); + if (available_device_count < arg->device_count() * replicas) { + return ResourceExhausted( + "Requested device count (%lld) exceeds the number of available devices " + "on the target (%lld)", + arg->device_count(), available_device_count); + } + + for (int64 i = 0; i < arg->device_count(); ++i) { + DeviceHandle device_handle; + device_handle.set_handle( + execute_backend_->stream_executors()[i * replicas]->device_ordinal()); + *result->add_device_handles() = device_handle; + } + + return tensorflow::Status::OK(); +} + +tensorflow::Status Service::Execute(const ExecuteRequest* arg, + ExecuteResponse* result) { + VLOG(1) << "running execute request: " << arg->ShortDebugString(); + + TF_ASSIGN_OR_RETURN(UserComputation * user_computation, + computation_tracker_.Resolve(arg->computation())); + + VersionedComputationHandle versioned_handle = + user_computation->GetVersionedHandle(); + + if (user_computation->request_count(versioned_handle.version) == 0) { + return InvalidArgument("computations may not be empty"); + } + + TF_ASSIGN_OR_RETURN( + std::shared_ptr program_shape, + user_computation->ComputeProgramShape(versioned_handle.version)); + + TF_ASSIGN_OR_RETURN( + std::vector arg_allocations, + ResolveAndValidateArguments(arg->arguments(), execute_backend_.get(), + 
execute_backend_->default_device_ordinal())); + + const Shape* shape_with_output_layout = arg->has_shape_with_output_layout() + ? &arg->shape_with_output_layout() + : nullptr; + + TF_ASSIGN_OR_RETURN( + std::unique_ptr module_config, + CreateModuleConfig(*program_shape, arg_allocations, + shape_with_output_layout, arg->seed())); + + VLOG(3) << "Execute created HloModuleConfig computation layout: " + << module_config->entry_computation_layout().ToString(); + + std::vector arguments; + for (const Allocation* allocation : arg_allocations) { + arguments.push_back(allocation->device_memory()); + } + + TF_ASSIGN_OR_RETURN( + std::shared_ptr executable, + BuildAndCacheExecutable(versioned_handle, std::move(module_config), + arguments, execute_backend_.get(), + execute_backend_->default_stream_executor(), + result->mutable_profile())); + + if (executable->dumping()) { + executable->session_module()->set_execution_platform( + execute_backend_->platform()->Name()); + TF_RETURN_IF_ERROR( + RecordArguments(arg_allocations, executable->session_module())); + } + + TF_ASSIGN_OR_RETURN( + *result->mutable_output(), + ExecuteAndRegisterResult( + executable.get(), arguments, execute_backend_.get(), + execute_backend_->default_stream_executor(), + "result of " + user_computation->name(), result->mutable_profile())); + + if (executable->dumping()) { + TF_ASSIGN_OR_RETURN(const Allocation* result_allocation, + allocation_tracker_.Resolve(result->output())); + TF_RETURN_IF_ERROR( + RecordResult(result_allocation, executable->session_module())); + TF_RETURN_IF_ERROR(executable->DumpSessionModule()); + } + + VLOG(1) << "successfully completed 'execute' request"; + return tensorflow::Status::OK(); +} + +tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, + ExecuteAsyncResponse* result) { + VLOG(1) << "running execute-async request: " << arg->ShortDebugString(); + + TF_ASSIGN_OR_RETURN(UserComputation * user_computation, + computation_tracker_.Resolve(arg->computation())); + + VersionedComputationHandle versioned_handle = + user_computation->GetVersionedHandle(); + if (user_computation->request_count(versioned_handle.version) == 0) { + return InvalidArgument("computations may not be empty"); + } + + TF_ASSIGN_OR_RETURN( + std::shared_ptr program_shape, + user_computation->ComputeProgramShape(versioned_handle.version)); + + TF_ASSIGN_OR_RETURN( + std::vector arg_allocations, + ResolveAndValidateArguments(arg->arguments(), execute_backend_.get(), + execute_backend_->default_device_ordinal())); + + const Shape* shape_with_output_layout = arg->has_shape_with_output_layout() + ? &arg->shape_with_output_layout() + : nullptr; + + TF_ASSIGN_OR_RETURN( + std::unique_ptr module_config, + CreateModuleConfig(*program_shape, arg_allocations, + shape_with_output_layout, arg->seed())); + + VLOG(3) << "ExecuteAsync created HloModuleConfig computation layout: " + << module_config->entry_computation_layout().ToString(); + + std::vector arguments; + for (const Allocation* allocation : arg_allocations) { + arguments.push_back(allocation->device_memory()); + } + + ExecutionProfile profile; + + TF_ASSIGN_OR_RETURN( + std::shared_ptr executable, + BuildAndCacheExecutable(versioned_handle, std::move(module_config), + arguments, execute_backend_.get(), + execute_backend_->default_stream_executor(), + &profile)); + + TF_RET_CHECK(!execute_backend_->Replicas().empty()); + // Set up streams. 
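For orientation, this is roughly what a caller of the synchronous Execute handler above sees end to end. A hedged sketch only: `service` is assumed to be a Service instance, `computation_handle` and `argument_handle` are assumed to have come from earlier Computation/Op and TransferToServer calls, and error handling is collapsed into TF_CHECK_OK.

ExecuteRequest request;
*request.mutable_computation() = computation_handle;  // assumed pre-existing
*request.add_arguments() = argument_handle;           // assumed pre-existing
request.set_seed(42);

ExecuteResponse response;
TF_CHECK_OK(service->Execute(&request, &response));

// The profile reflects the BuildAndCacheExecutable bookkeeping above.
if (response.profile().compilation_cache_hit()) {
  VLOG(1) << "reused a cached executable";
} else {
  VLOG(1) << "compiled in " << response.profile().compile_time_ms() << " ms";
}
const GlobalDataHandle& output = response.output();
// 'output' can now be passed to TransferToClient to fetch the literal.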
+ std::vector> streams; + + auto stream_releaser = ::tensorflow::gtl::MakeCleanup([this, &streams]() { + for (std::unique_ptr& stream : streams) { + execute_backend_->ReleaseStream(std::move(stream)); + } + }); + + for (se::StreamExecutor* executor : execute_backend_->Replicas()) { + TF_ASSIGN_OR_RETURN(std::unique_ptr stream, + execute_backend_->AcquireStream(executor)); + // Push back after so that the releaser only sees real streams. + streams.push_back(std::move(stream)); + } + + perftools::gputools::DeviceMemoryBase result_data; + for (const std::unique_ptr& stream : streams) { + ExecutableRunOptions options; + options.set_stream(stream.get()); + options.set_allocator(execute_backend_->memory_allocator()); + options.set_inter_op_thread_pool(execute_backend_->inter_op_thread_pool()); + options.set_intra_op_thread_pool( + execute_backend_->eigen_intra_op_thread_pool_device()); + + TF_ASSIGN_OR_RETURN(perftools::gputools::DeviceMemoryBase this_result_data, + executable->ExecuteAsyncOnStream(&options, arguments)); + + // Take the first result. + if (result_data == nullptr) { + result_data = this_result_data; + } + } + + auto output = allocation_tracker_.Register( + execute_backend_.get(), execute_backend_->default_device_ordinal(), + result_data, executable->result_shape(), + "result of " + user_computation->name()); + + *result->mutable_execution() = execution_tracker_.Register( + execute_backend_.get(), std::move(streams), profile, output); + streams.clear(); + + VLOG(1) << "successfully completed 'execute-async' request"; + return tensorflow::Status::OK(); +} + +tensorflow::Status Service::WaitForExecution(const WaitForExecutionRequest* arg, + WaitForExecutionResponse* result) { + TF_ASSIGN_OR_RETURN(const auto execution, + execution_tracker_.Resolve(arg->execution())); + + TF_RETURN_IF_ERROR(execution->BlockUntilDone()); + + *result->mutable_output() = execution->result(); + *result->mutable_profile() = execution->profile(); + + TF_RETURN_IF_ERROR(execution_tracker_.Unregister(arg->execution())); + VLOG(1) << "successfully completed 'wait-for-execution' request"; + return tensorflow::Status::OK(); +} + +tensorflow::Status Service::TransferToClient(const TransferToClientRequest* arg, + TransferToClientResponse* result) { + TF_ASSIGN_OR_RETURN(const Allocation* allocation, + allocation_tracker_.Resolve(arg->data())); + + const Shape* literal_shape; + if (arg->has_shape_with_layout()) { + if (!LayoutUtil::HasLayout(arg->shape_with_layout())) { + return InvalidArgument("shape_with_layout must have layout if present."); + } + literal_shape = &arg->shape_with_layout(); + } else { + literal_shape = &allocation->shape(); + } + + return LiteralFromAllocation(allocation, *literal_shape, + result->mutable_literal()); +} + +tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg, + TransferToServerResponse* result) { + const Literal& literal = arg->literal(); + const Shape& shape = literal.shape(); + + if (ShapeUtil::IsTuple(shape) && execute_backend_->Replicas().size() > 1) { + // TODO(b/32990684): Tuple transfers to host end up allocating further + // buffers - implement that correctly. 
+ return Unimplemented( + "Tuple transfers to the device not supported with replication."); + } + + se::StreamExecutor* stream_executor; + if (arg->has_device_handle()) { + TF_ASSIGN_OR_RETURN( + stream_executor, + execute_backend_->stream_executor(arg->device_handle().handle())); + } else { + stream_executor = execute_backend_->default_stream_executor(); + } + + // Allocate memory on the device, using the stream executor. The size of the + // allocation is obtained by examining the shape of the literal passed from + // the client. An allocation handle is returned in the response. + int64 allocation_size = + execute_backend_->transfer_manager()->GetByteSizeRequirement(shape); + + TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase allocation, + execute_backend_->memory_allocator()->Allocate( + stream_executor->device_ordinal(), allocation_size)); + + *result->mutable_data() = allocation_tracker_.Register( + execute_backend_.get(), stream_executor->device_ordinal(), allocation, + shape, tensorflow::strings::StrCat("TransferToServer literal of size ", + allocation_size)); + + TF_ASSIGN_OR_RETURN( + auto replicas, + execute_backend_->Replicas(stream_executor->device_ordinal())); + for (se::StreamExecutor* executor : replicas) { + TF_RETURN_IF_ERROR( + execute_backend_->transfer_manager()->TransferLiteralToDevice( + executor, literal, &allocation)); + } + return tensorflow::Status::OK(); +} + +tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg, + TransferToInfeedResponse* result) { + const int64 replica_count = execute_backend_->Replicas().size(); + if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) { + return FailedPrecondition( + "%s", + tensorflow::strings::StrCat( + "The replica_id=", arg->replica_id(), + " on TransferToInfeedRequest not in range [0, replica_count=", + replica_count, ").") + .c_str()); + } + + se::StreamExecutor* executor; + if (arg->has_device_handle()) { + TF_ASSIGN_OR_RETURN( + auto replicas, + execute_backend_->Replicas(arg->device_handle().handle())); + executor = replicas[arg->replica_id()]; + } else { + executor = execute_backend_->Replicas()[arg->replica_id()]; + } + + return execute_backend_->transfer_manager()->TransferLiteralToInfeed( + executor, arg->literal()); +} + +tensorflow::Status Service::ResetDevice(const ResetDeviceRequest* arg, + ResetDeviceResponse* result) { + int first_device_ordinal = arg->has_device_handle() + ? 
arg->device_handle().handle() + : execute_backend_->default_device_ordinal(); + TF_ASSIGN_OR_RETURN(auto executors, + execute_backend_->Replicas(first_device_ordinal)); + for (se::StreamExecutor* executor : executors) { + TF_RETURN_IF_ERROR( + execute_backend_->transfer_manager()->ResetDevice(executor)); + } + return tensorflow::Status::OK(); +} + +tensorflow::Status Service::TransferToClientInProcess( + const TransferToClientInProcessRequest* arg, + TransferToClientInProcessResponse* result) { + TF_RETURN_IF_ERROR(CheckRunsInClientProcess("TransferToClientInProcess")); + + TF_ASSIGN_OR_RETURN(const Allocation* allocation, + allocation_tracker_.Resolve(arg->data())); + + void* buffer = reinterpret_cast(arg->buffer()); + int64 size = ShapeUtil::ByteSizeOf(allocation->shape()); + TF_ASSIGN_OR_RETURN( + se::StreamExecutor * executor, + allocation->backend()->stream_executor(allocation->device_ordinal())); + + return allocation->backend()->transfer_manager()->TransferBufferFromDevice( + executor, allocation->device_memory(), size, buffer); +} + +tensorflow::Status Service::TransferToServerInProcess( + const TransferToServerInProcessRequest* arg, + TransferToServerInProcessResponse* result) { + TF_RETURN_IF_ERROR(CheckRunsInClientProcess("TransferToServerInProcess")); + + const Shape& shape = arg->shape(); + + if (ShapeUtil::IsTuple(shape) && execute_backend_->Replicas().size() > 1) { + // TODO(b/32990684): Tuple transfers to host end up allocating further + // buffers - implement that correctly. + return Unimplemented( + "Tuple transfers to the device not supported with replication."); + } + + if (!LayoutUtil::HasLayout(shape)) { + return InvalidArgument("shape must have layout"); + } + + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape)); + + const void* buffer = reinterpret_cast(arg->buffer()); + + // Allocate memory on the device, using the stream executor. The size of the + // allocation is obtained by examining the shape of the literal passed from + // the client. An allocation handle is returned in the response. 
+ int64 allocation_size = + execute_backend_->transfer_manager()->GetByteSizeRequirement(shape); + se::StreamExecutor* stream_executor = + execute_backend_->default_stream_executor(); + + TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase allocation, + execute_backend_->memory_allocator()->Allocate( + stream_executor->device_ordinal(), allocation_size)); + + *result->mutable_data() = allocation_tracker_.Register( + execute_backend_.get(), stream_executor->device_ordinal(), allocation, + shape, tensorflow::strings::StrCat("TransferToServer literal of size ", + allocation_size)); + + for (se::StreamExecutor* executor : execute_backend_->Replicas()) { + TF_RETURN_IF_ERROR( + execute_backend_->transfer_manager()->TransferBufferToDevice( + executor, allocation_size, buffer, &allocation)); + } + return tensorflow::Status::OK(); +} + +tensorflow::Status Service::IsConstant(const IsConstantRequest* arg, + IsConstantResponse* result) { + TF_ASSIGN_OR_RETURN(UserComputation * user_computation, + computation_tracker_.Resolve(arg->computation())); + + VersionedComputationHandle versioned_handle = + user_computation->GetVersionedHandleAtOperation(arg->operand()); + + if (user_computation->request_count(versioned_handle.version) == 0) { + return InvalidArgument("computations may not be empty"); + } + + TF_ASSIGN_OR_RETURN(bool is_constant, + user_computation->IsConstant(arg->operand())); + + result->set_is_constant(is_constant); + return tensorflow::Status::OK(); +} + +tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg, + ComputeConstantResponse* result) { + TF_ASSIGN_OR_RETURN(UserComputation * user_computation, + computation_tracker_.Resolve(arg->computation())); + + VersionedComputationHandle versioned_handle = + user_computation->GetVersionedHandleAtOperation(arg->operand()); + + if (user_computation->request_count(versioned_handle.version) == 0) { + return InvalidArgument("computations may not be empty"); + } + + TF_ASSIGN_OR_RETURN(bool is_constant, + user_computation->IsConstant(arg->operand())); + + if (!is_constant) { + return InvalidArgument("Operand to ComputeConstant depends on parameter."); + } + + // We can't use ComputeProgramShape because it checks that all parameter + // instructions are present and contiguous. Instead construct ProgramShape + // directly. + ProgramShape program_shape; + TF_ASSIGN_OR_RETURN(*program_shape.mutable_result(), + user_computation->GetShape(arg->operand())); + + TF_DCHECK_OK(ShapeUtil::ValidateShape(program_shape.result())); + + Shape shape_with_output_layout(program_shape.result()); + if (arg->has_output_layout()) { + TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape( + arg->output_layout(), shape_with_output_layout)); + *shape_with_output_layout.mutable_layout() = arg->output_layout(); + } + + TF_ASSIGN_OR_RETURN( + std::unique_ptr module_config, + CreateModuleConfig( + program_shape, {}, + arg->has_output_layout() ? 
&shape_with_output_layout : nullptr, + /*seed=*/0)); + + TF_ASSIGN_OR_RETURN( + std::shared_ptr executable, + BuildExecutable(versioned_handle, std::move(module_config), + /*executable_for_compute_constant=*/true, + /*arguments=*/{}, compute_constant_backend_.get(), + compute_constant_backend_->default_stream_executor())); + + TF_ASSIGN_OR_RETURN( + *result->mutable_output(), + ExecuteAndRegisterResult( + executable.get(), /*arguments=*/{}, compute_constant_backend_.get(), + compute_constant_backend_->default_stream_executor(), + "constant computed from " + user_computation->name(), + /*profile=*/nullptr)); + return tensorflow::Status::OK(); +} + +tensorflow::Status Service::GetShape(const GetShapeRequest* arg, + GetShapeResponse* result) { + TF_ASSIGN_OR_RETURN(const Allocation* allocation, + allocation_tracker_.Resolve(arg->data())); + *result->mutable_shape() = allocation->shape(); + return tensorflow::Status::OK(); +} + +tensorflow::Status Service::GetComputationShape( + const GetComputationShapeRequest* arg, + GetComputationShapeResponse* result) { + TF_ASSIGN_OR_RETURN(UserComputation * computation, + computation_tracker_.Resolve(arg->computation())); + + VersionedComputationHandle versioned_handle = + computation->GetVersionedHandle(); + + TF_ASSIGN_OR_RETURN( + auto program_shape, + computation->ComputeProgramShape(versioned_handle.version)); + *result->mutable_program_shape() = *program_shape; + return tensorflow::Status::OK(); +} + +tensorflow::Status Service::GetLocalShape(const GetLocalShapeRequest* arg, + GetLocalShapeResponse* result) { + TF_ASSIGN_OR_RETURN(UserComputation * computation, + computation_tracker_.Resolve(arg->computation())); + + TF_ASSIGN_OR_RETURN(*result->mutable_shape(), + computation->GetShape(arg->operand())); + return tensorflow::Status::OK(); +} + +tensorflow::Status Service::GetComputationStats( + const ComputationStatsRequest* arg, ComputationStatsResponse* result) { + TF_ASSIGN_OR_RETURN(UserComputation * user_computation, + computation_tracker_.Resolve(arg->computation())); + + VersionedComputationHandle versioned_handle = + user_computation->GetVersionedHandle(); + + TF_ASSIGN_OR_RETURN(std::unique_ptr module, + computation_tracker_.BuildHloModule(versioned_handle)); + + MakeHloDumper()(*module, "computation statistics subject"); + + // Run HLO analysis to get the computation statistics. 
+ HloCostAnalysis analysis; + + TF_RETURN_IF_ERROR( + module->entry_computation()->root_instruction()->Accept(&analysis)); + + ComputationStats stats; + stats.set_flop_count(analysis.flop_count()); + stats.set_transcendental_count(analysis.transcendental_count()); + *result->mutable_stats() = stats; + return tensorflow::Status::OK(); +} + +tensorflow::Status Service::CheckRunsInClientProcess( + const string& method_name) const { + if (runs_in_client_process_) { + return tensorflow::Status::OK(); + } else { + return FailedPrecondition( + "%s only supported if service runs in the same process as the client", + method_name.c_str()); + } +} + +template +tensorflow::Status Service::AddInstruction( + const RequestT* arg, ResponseT* result, + const std::function(UserComputation*)>& + adder) { + TF_ASSIGN_OR_RETURN(UserComputation * computation, + computation_tracker_.Resolve(arg->computation())); + + TF_ASSIGN_OR_RETURN(*result->mutable_output(), adder(computation)); + return tensorflow::Status::OK(); +} + +tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { + TF_ASSIGN_OR_RETURN(UserComputation * computation, + computation_tracker_.Resolve(arg->computation())); + StatusOr handle; + + switch (arg->op_case()) { + case OpRequest::kBinaryOpRequest: + handle = computation->AddBinaryInstruction(arg->binary_op_request()); + break; + case OpRequest::kBroadcastRequest: + handle = computation->AddBroadcastInstruction(arg->broadcast_request()); + break; + case OpRequest::kCallRequest: { + TF_ASSIGN_OR_RETURN( + UserComputation * to_apply, + computation_tracker_.Resolve(arg->call_request().to_apply())); + handle = computation->AddCallInstruction(arg->call_request(), *to_apply); + break; + } + case OpRequest::kConcatenateRequest: + handle = + computation->AddConcatenateInstruction(arg->concatenate_request()); + break; + case OpRequest::kConstantRequest: + handle = computation->AddConstantInstruction(arg->constant_request()); + break; + case OpRequest::kConvertRequest: + handle = computation->AddConvertInstruction(arg->convert_request()); + break; + case OpRequest::kConvolveRequest: + handle = computation->AddConvolveInstruction(arg->convolve_request()); + break; + case OpRequest::kCrossReplicaSumRequest: + handle = computation->AddCrossReplicaSumInstruction( + arg->cross_replica_sum_request()); + break; + case OpRequest::kCustomCallRequest: + handle = + computation->AddCustomCallInstruction(arg->custom_call_request()); + break; + case OpRequest::kDynamicSliceRequest: + handle = + computation->AddDynamicSliceInstruction(arg->dynamic_slice_request()); + break; + case OpRequest::kDynamicUpdateSliceRequest: + handle = computation->AddDynamicUpdateSliceInstruction( + arg->dynamic_update_slice_request()); + break; + case OpRequest::kGetTupleElementRequest: + handle = computation->AddGetTupleElementInstruction( + arg->get_tuple_element_request()); + break; + case OpRequest::kInfeedRequest: + handle = computation->AddInfeedInstruction(arg->infeed_request()); + break; + case OpRequest::kMapRequest: { + TF_ASSIGN_OR_RETURN( + UserComputation * to_apply, + computation_tracker_.Resolve(arg->map_request().to_apply())); + handle = computation->AddMapInstruction(arg->map_request(), *to_apply); + break; + } + case OpRequest::kPadRequest: + handle = computation->AddPadInstruction(arg->pad_request()); + break; + case OpRequest::kParameterRequest: + handle = computation->AddParameterInstruction(arg->parameter_request()); + break; + case OpRequest::kReduceRequest: { + TF_ASSIGN_OR_RETURN( + 
UserComputation * to_apply, + computation_tracker_.Resolve(arg->reduce_request().to_apply())); + handle = + computation->AddReduceInstruction(arg->reduce_request(), *to_apply); + break; + } + case OpRequest::kReduceWindowRequest: { + TF_ASSIGN_OR_RETURN(UserComputation * to_apply, + computation_tracker_.Resolve( + arg->reduce_window_request().to_apply())); + handle = computation->AddReduceWindowInstruction( + arg->reduce_window_request(), *to_apply); + break; + } + case OpRequest::kReshapeRequest: + handle = computation->AddReshapeInstruction(arg->reshape_request()); + break; + case OpRequest::kReverseRequest: + handle = computation->AddReverseInstruction(arg->reverse_request()); + break; + case OpRequest::kRngRequest: + handle = computation->AddRngInstruction(arg->rng_request()); + break; + case OpRequest::kSelectAndScatterRequest: { + TF_ASSIGN_OR_RETURN(UserComputation * select, + computation_tracker_.Resolve( + arg->select_and_scatter_request().select())); + TF_ASSIGN_OR_RETURN(UserComputation * scatter, + computation_tracker_.Resolve( + arg->select_and_scatter_request().scatter())); + handle = computation->AddSelectAndScatterInstruction( + arg->select_and_scatter_request(), *select, *scatter); + break; + } + case OpRequest::kSliceRequest: + handle = computation->AddSliceInstruction(arg->slice_request()); + break; + case OpRequest::kTernaryOpRequest: + handle = computation->AddTernaryInstruction(arg->ternary_op_request()); + break; + case OpRequest::kTraceRequest: + return computation->AddTraceInstruction(arg->trace_request()); + case OpRequest::kUnaryOpRequest: + handle = computation->AddUnaryInstruction(arg->unary_op_request()); + break; + case OpRequest::kVariadicOpRequest: + handle = computation->AddVariadicInstruction(arg->variadic_op_request()); + break; + case OpRequest::kWhileRequest: { + TF_ASSIGN_OR_RETURN( + UserComputation * condition, + computation_tracker_.Resolve(arg->while_request().condition())); + TF_ASSIGN_OR_RETURN( + UserComputation * body, + computation_tracker_.Resolve(arg->while_request().body())); + handle = computation->AddWhileInstruction(arg->while_request(), + *condition, *body); + break; + } + case OpRequest::kSendRequest: { + TF_RETURN_IF_ERROR( + channel_tracker_.RegisterSend(arg->send_request().channel_handle())); + TF_RETURN_IF_ERROR(computation->AddSendInstruction(arg->send_request())); + return tensorflow::Status::OK(); + } + case OpRequest::kRecvRequest: { + TF_RETURN_IF_ERROR( + channel_tracker_.RegisterRecv(arg->recv_request().channel_handle())); + handle = computation->AddRecvInstruction(arg->recv_request()); + break; + } + default: + return InvalidArgument("Unsupported operation"); + } + TF_ASSIGN_OR_RETURN(*result->mutable_output(), handle); + return tensorflow::Status::OK(); +} + +tensorflow::Status Service::SnapshotComputation( + const SnapshotComputationRequest* arg, + SnapshotComputationResponse* result) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr module, + computation_tracker_.SnapshotComputation(arg->computation())); + + result->set_allocated_module(module.release()); + + return tensorflow::Status::OK(); +} + +tensorflow::Status Service::LoadComputationSnapshot( + const LoadComputationSnapshotRequest* arg, + LoadComputationSnapshotResponse* result) { + TF_ASSIGN_OR_RETURN(*result->mutable_computation(), + computation_tracker_.LoadSessionModule(arg->module())); + return tensorflow::Status::OK(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h new file mode 100644 
index 0000000000..1141e99fe3 --- /dev/null +++ b/tensorflow/compiler/xla/service/service.h @@ -0,0 +1,457 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SERVICE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_SERVICE_H_ + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/executable_run_options.h" +#include "tensorflow/compiler/xla/legacy_flags/service_flags.h" +#include "tensorflow/compiler/xla/service/allocation_tracker.h" +#include "tensorflow/compiler/xla/service/backend.h" +#include "tensorflow/compiler/xla/service/channel_tracker.h" +#include "tensorflow/compiler/xla/service/compilation_cache.h" +#include "tensorflow/compiler/xla/service/compiler.h" +#include "tensorflow/compiler/xla/service/computation_tracker.h" +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/executable.h" +#include "tensorflow/compiler/xla/service/execution_tracker.h" +#include "tensorflow/compiler/xla/service/hlo_execution_profile.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/session.pb.h" +#include "tensorflow/compiler/xla/service/user_computation.h" +#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" +#include "tensorflow/compiler/xla/service_interface.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla.pb.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { + +// Options to configure the service when it is created. +class ServiceOptions { + public: + // Set the platform backing the service, or nullptr for the default platform. + ServiceOptions& set_platform(perftools::gputools::Platform* platform); + perftools::gputools::Platform* platform() const; + + // Set the number of replicas to use when compiling replicated + // programs. The default is -1 meaning that the value is read from + // the xla_replicas flag. + ServiceOptions& set_number_of_replicas(int number_of_replicas); + int number_of_replicas() const; + + private: + perftools::gputools::Platform* platform_ = nullptr; + int number_of_replicas_ = -1; +}; + +// The XLA service object, which is the same across all +// platforms. It maintains the service state of computations and allocations, +// and delegates target-specific requests to the target-specific infrastructure +// (target-specific compiler, StreamExecutor). +class Service : public ServiceInterface { + public: + // Factory method for creating a new Service. 
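As a usage sketch for the options class above and the factory declared just below (a hypothetical call site; most users would reach the service through the client libraries rather than construct it directly):

tensorflow::Status MakeSingleReplicaService(
    std::unique_ptr<xla::Service>* service) {
  xla::ServiceOptions options;
  options.set_number_of_replicas(1);  // -1 would defer to the xla_replicas flag
  TF_ASSIGN_OR_RETURN(*service, xla::Service::NewService(options));
  return tensorflow::Status::OK();
}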
+  static StatusOr<std::unique_ptr<Service>> NewService(
+      perftools::gputools::Platform* platform = nullptr);
+  static StatusOr<std::unique_ptr<Service>> NewService(
+      const ServiceOptions& options);
+
+  // Creates a new computation with the given name.
+  // A unique ComputationHandle is returned.
+  tensorflow::Status Computation(const ComputationRequest* arg,
+                                 ComputationResponse* result) override;
+
+  // Unregisters a previously-allocated global handle.
+  //
+  // If the handle given is not currently allocated, a NOT_FOUND status is
+  // returned.
+  tensorflow::Status Unregister(const UnregisterRequest* arg,
+                                UnregisterResponse* result) override;
+
+  // Deconstructs a tuple. Returns a newly created GlobalDataHandle for each
+  // element in the tuple.
+  tensorflow::Status DeconstructTuple(
+      const DeconstructTupleRequest* arg,
+      DeconstructTupleResponse* result) override;
+
+  // Modifies the provided computation so that subsequent executions
+  // will compute the provided ComputationDataHandle, rather than the
+  // last expression enqueued on that Computation.
+  tensorflow::Status SetReturnValue(const SetReturnValueRequest* arg,
+                                    SetReturnValueResponse* results) override;
+
+  // Executes a computation with the provided global data passed as
+  // immutable arguments. Returns global data output and execution timing.
+  tensorflow::Status Execute(const ExecuteRequest* arg,
+                             ExecuteResponse* result) override;
+
+  // Executes one or more computations in parallel with the provided global
+  // data passed as immutable arguments. Returns global data output for each
+  // computation.
+  tensorflow::Status ExecuteParallel(const ExecuteParallelRequest* arg,
+                                     ExecuteParallelResponse* result) override;
+
+  // Requests one or more device handles from the target.
+  //
+  // When N device handles are requested and the number of replicas is R, at
+  // least N * R devices must be available. The devices are assigned based on
+  // the device ordinals such that the first R available devices are assigned
+  // to the first set of replicas, and the next R devices to the second set of
+  // replicas, etc. Each returned device handle represents the device with
+  // replica id 0.
+  tensorflow::Status GetDeviceHandles(
+      const GetDeviceHandlesRequest* arg,
+      GetDeviceHandlesResponse* result) override;
+
+  // Asynchronously executes a computation with provided arguments. Invokes
+  // the provided computation with the provided global data passed as
+  // immutable arguments. Returns a handle to the execution.
+  tensorflow::Status ExecuteAsync(const ExecuteAsyncRequest* arg,
+                                  ExecuteAsyncResponse* result) override;
+
+  // Waits until the specified execution is complete and returns the result.
+  // Calling this API multiple times with the same execution handle fails with
+  // an error, since the execution handle is destroyed after the first call.
+  tensorflow::Status WaitForExecution(
+      const WaitForExecutionRequest* arg,
+      WaitForExecutionResponse* result) override;
+
+  // Requests that global data be transferred to the client in literal form.
+  tensorflow::Status TransferToClient(
+      const TransferToClientRequest* arg,
+      TransferToClientResponse* result) override;
+
+  // Requests that global data be copied into a buffer supplied by the client.
+  tensorflow::Status TransferToClientInProcess(
+      const TransferToClientInProcessRequest* arg,
+      TransferToClientInProcessResponse* result) override;
+
+  // Transfers data from a literal provided by the client, into device memory.
+ tensorflow::Status TransferToServer( + const TransferToServerRequest* arg, + TransferToServerResponse* result) override; + + // Transfers data from a literal provided by the client, into the Infeed + // buffer of the device. + tensorflow::Status TransferToInfeed( + const TransferToInfeedRequest* arg, + TransferToInfeedResponse* result) override; + + // Resets the device, clearing all existing state on the device. + tensorflow::Status ResetDevice(const ResetDeviceRequest* arg, + ResetDeviceResponse* result) override; + + // Transfers data from a buffer provided by the client, into device memory. + tensorflow::Status TransferToServerInProcess( + const TransferToServerInProcessRequest* arg, + TransferToServerInProcessResponse* result) override; + + // Tests if an expression is a compile-time constant. + tensorflow::Status IsConstant(const IsConstantRequest* arg, + IsConstantResponse* result) override; + + // Computes the value of a constant expression. + tensorflow::Status ComputeConstant(const ComputeConstantRequest* arg, + ComputeConstantResponse* result) override; + + // Returns the shape (with layout) of an array associated with a given data + // handle. + tensorflow::Status GetShape(const GetShapeRequest* arg, + GetShapeResponse* result) override; + + // Returns the program shape of the computation associated with the given + // handle. + tensorflow::Status GetComputationShape( + const GetComputationShapeRequest* arg, + GetComputationShapeResponse* result) override; + + ///// + // Computation-oriented methods. + + // Enqueues an Op on the computation. + tensorflow::Status Op(const OpRequest* arg, OpResponse* result) override; + + // Retrieves the inferred shape for a value within a computation. + tensorflow::Status GetLocalShape(const GetLocalShapeRequest* arg, + GetLocalShapeResponse* result) override; + + // Retrieves the statistics of a computation. + tensorflow::Status GetComputationStats( + const ComputationStatsRequest* arg, + ComputationStatsResponse* result) override; + + // Snapshots the current state of a computation handle into a serializable + // protocol buffer form, so it can be loaded via + // LoadComputationSnapshot. + tensorflow::Status SnapshotComputation( + const SnapshotComputationRequest* arg, + SnapshotComputationResponse* result) override; + + // Loads a computation from a serialized protocol buffer created via + // SnapshotComputation. + tensorflow::Status LoadComputationSnapshot( + const LoadComputationSnapshotRequest* arg, + LoadComputationSnapshotResponse* result) override; + + // Creates a unique channel handle that can be used for Send/Recv + // instructions. + tensorflow::Status CreateChannelHandle( + const CreateChannelHandleRequest* arg, + CreateChannelHandleResponse* result) override; + + // Returns the ComputationTracker of the current service instance. + // Only used in unit tests to access user computations from client. + const ComputationTracker& computation_tracker() { + return computation_tracker_; + } + + // Returns the backend used to execute computations. + const Backend& backend() const { return *execute_backend_; } + Backend* mutable_backend() { return execute_backend_.get(); } + + protected: + // The constructor is private. Use the NewService factory to create new + // service objects. + Service(std::unique_ptr backend, + std::unique_ptr compute_constant_backend); + + static StatusOr> CreateComputeConstantBackend(); + + // Resolves the given argument handles in the allocation tracker and returns + // the corresponding allocations. 
the corresponding allocations. The function also verifies that each
+  // allocation matches the given backend and device ordinal.
+  StatusOr<std::vector<const Allocation*>> ResolveAndValidateArguments(
+      tensorflow::gtl::ArraySlice<const GlobalDataHandle*> arguments,
+      const Backend* backend, int device_ordinal);
+
+  // Creates an HLO module config for the given program shape and arguments.
+  // If shape_with_output_layout is not null, then the computation output
+  // layout is set to the layout of the given shape.
+  StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
+      const ProgramShape& program_shape,
+      tensorflow::gtl::ArraySlice<const Allocation*> arguments,
+      const Shape* shape_with_output_layout, uint64 seed);
+
+  // Builds an Executable for the given parameters. If
+  // executable_for_compute_constant is true, then the executable is intended
+  // to be used for ComputeConstant, which means dead parameter instructions
+  // are not included in the executable. The parameter "profile" can
+  // optionally point to an ExecutionProfile object which will be filled in
+  // with profile data relevant to compilation.
+  StatusOr<std::unique_ptr<Executable>> BuildExecutable(
+      const VersionedComputationHandle& versioned_handle,
+      std::unique_ptr<HloModuleConfig> module_config,
+      bool executable_for_compute_constant,
+      const tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
+          arguments,
+      Backend* backend, perftools::gputools::StreamExecutor* executor);
+
+  // Same as BuildExecutable() above, but builds a list of Executables for the
+  // given computations that may interact with each other.
+  StatusOr<std::vector<std::unique_ptr<Executable>>> BuildExecutables(
+      std::vector<VersionedComputationHandle> versioned_handles,
+      std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
+      Backend* backend,
+      std::vector<perftools::gputools::StreamExecutor*> executors);
+
+  // Similar to BuildExecutable, but looks in the compilation cache for the
+  // executable first. If the executable is not in the cache, it is built and
+  // inserted into the cache.
+  StatusOr<std::shared_ptr<Executable>> BuildAndCacheExecutable(
+      const VersionedComputationHandle& versioned_handle,
+      std::unique_ptr<HloModuleConfig> module_config,
+      const tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
+          arguments,
+      Backend* backend, perftools::gputools::StreamExecutor* executor,
+      ExecutionProfile* profile);
+
+  // Runs the given executable with the given arguments and registers the
+  // result in the allocation tracker. The handle of the result from the
+  // tracker is returned. If the parameter "profile" is not null, it points to
+  // an ExecutionProfile object which will be filled in with profile data.
+  StatusOr<GlobalDataHandle> ExecuteAndRegisterResult(
+      Executable* executable,
+      const tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
+          arguments,
+      Backend* backend, perftools::gputools::StreamExecutor* executor,
+      const string& result_tag, ExecutionProfile* profile);
+
+  // Runs the given executables with the given arguments and registers the
+  // result from each executable in the allocation tracker. The handles of the
+  // results from the tracker are returned.
+  StatusOr<std::vector<GlobalDataHandle>> ExecuteParallelAndRegisterResult(
+      tensorflow::gtl::ArraySlice<Executable*> executables,
+      tensorflow::gtl::ArraySlice<
+          std::vector<perftools::gputools::DeviceMemoryBase>>
+          arguments,
+      Backend* backend,
+      tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
+          executors,
+      tensorflow::gtl::ArraySlice<string> result_tags);
+
+  // Dumps the executed HLO according to service-associated flags.
+  static void DumpExecutedHlo(const HloModule& module, const string& label,
+                              const HloExecutionProfile* profile);
+
+  // Returns an HLO dumper for use in the compiler (it refers to flags
+  // associated with the service).
+  static Compiler::HloDumper MakeHloDumper();
+
+  // Convenience function for adding a function to a user computation.
+ template + tensorflow::Status AddInstruction( + const RequestT* arg, ResponseT* result, + const std::function(UserComputation*)>& + adder); + + // If the service is running in the client process + // (runs_in_client_process_ is true) then return + // tensorflow::Status::OK. Otherwise return an appropriate error + // status with the given method name. Used for "InProcess" methods. + tensorflow::Status CheckRunsInClientProcess(const string& method_name) const; + + // Convenience function which checks whether the given shape_with_layout + // (presumably passed by the client to set the result layout) is valid for the + // given computation result shape. + tensorflow::Status ValidateResultShapeWithLayout( + const Shape& shape_with_layout, const Shape& result_shape) const; + + // Convenience wrapper for calling Executable::ExecuteOnStream. Sets up a + // timer for the execution, sets up HLO profiling if enabled, and fills in the + // given ExecutionProfile if non-null. The given execute_func should be a + // function which calls the desired ExecuteOnStream overload with the supplied + // arguments. The ExecuteOnStream overloads return different types so this + // method is templated on return-type of the execute function. + template + ReturnT ExecuteOnStreamWrapper( + Executable* executable, const ExecutableRunOptions* run_options, + ExecutionProfile* profile, + std::function + execute_func); + + // Tracks computations built via the API. + ComputationTracker computation_tracker_; + + // Tracks channels created via the API. + ChannelTracker channel_tracker_; + + // Tracks allocations made via the API and computation execution. + AllocationTracker allocation_tracker_; + + // Tracks asynchronously launched executions via the API. + ExecutionTracker execution_tracker_; + + // Cache containing previously built Executables. + CompilationCache compilation_cache_; + + // Backend to compile and execute computations on. + // + // TODO(b/28616830): Support multiple backends for execution. + std::unique_ptr execute_backend_; + + // Backend to use when executing ComputeConstant. + std::unique_ptr compute_constant_backend_; + + // Whether the service runs in the same process as the client. + bool runs_in_client_process_ = false; + + TF_DISALLOW_COPY_AND_ASSIGN(Service); +}; + +template +ReturnT Service::ExecuteOnStreamWrapper( + Executable* executable, const ExecutableRunOptions* run_options, + ExecutionProfile* profile, + std::function + execute_func) { + perftools::gputools::Stream* stream = run_options->stream(); + std::unique_ptr timer; + if (profile != nullptr) { + timer.reset(new perftools::gputools::Timer(stream->parent())); + stream->InitTimer(timer.get()).ThenStartTimer(timer.get()); + } + + VLOG(1) << "enqueueing executable on stream..."; + // If the profiling flag isn't enabled, we pass nullptr as the profile to + // indicate profiling is not requested. + HloExecutionProfile hlo_execution_profile; + legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags(); + HloExecutionProfile* profile_ptr = + flags->xla_hlo_profile && executable->hlo_profiling_enabled() + ? &hlo_execution_profile + : nullptr; + + auto return_value = execute_func(executable, run_options, profile_ptr); + + if (profile != nullptr) { + VLOG(1) << "enqueueing 'stop timer' and blocking host until done..."; + stream->ThenStopTimer(timer.get()).BlockHostUntilDone(); + VLOG(1) << "done with block-host-until-done"; + + // Merge in run time profile information from the executable. 
+ profile->MergeFrom(executable->execution_profile()); + + // Overall execution time (in nanoseconds) from the executor timer. + profile->set_compute_and_transfer_time_ns(timer->Nanoseconds()); + + // TODO(b/28123297): On GPU we end up including transfer time in + // the compute time this way. Instead, we should get the correct + // value by measuring it. Setting the field here at least lets + // benchmarks provide *some* value for GPU computations. + // + // TODO(b/28447609): The value in compute_and_transfer_time_ns is actually + // the compute time without the transfer time, so this way we get the + // correct compute time. We should instead have the correct value for + // compute_and_transfer_time and set compute_time to the compute time. + if (profile->compute_time_ns() == 0) { + profile->set_compute_time_ns(profile->compute_and_transfer_time_ns()); + } + } + + if (profile_ptr != nullptr) { + HloCostAnalysis analysis; + tensorflow::Status analysis_status = + executable->module().entry_computation()->root_instruction()->Accept( + &analysis); + if (analysis_status.ok()) { + XLA_LOG_LINES(tensorflow::INFO, + profile_ptr->ToString( + stream->parent()->GetDeviceDescription(), analysis)); + } + DumpExecutedHlo(executable->module(), "Service::Execute", profile_ptr); + } + + return return_value; +} +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_SERVICE_H_ diff --git a/tensorflow/compiler/xla/service/session.proto b/tensorflow/compiler/xla/service/session.proto new file mode 100644 index 0000000000..ead3c1eb10 --- /dev/null +++ b/tensorflow/compiler/xla/service/session.proto @@ -0,0 +1,91 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This proto file defines messages which store the state of XLA +// computations within the XLA service. A computation is stored as a record +// of the operation requests used to build it. +syntax = "proto3"; + +import "tensorflow/compiler/xla/xla_data.proto"; + +package xla; + +// Describes a single operation request. +message OperationRequest { + ComputationDataHandle output_handle = 1; + Shape output_shape = 2; + + // For operations which call embedded computations such as "Map", these are + // the version(s) that the embedded computation should be called at. A version + // value of a computation is the ComputationDataHandle of the root of the + // computation at the point in time. + // + // "Call", "Map", "Reduce", and "ReduceWindow" operations take a single + // embedded computation so this field will have a single value for those + // operations. + // + // "While" operation takes two; index 0 is the "condition" version and index 1 + // is the "body" version. + repeated int64 embedded_computation_versions = 3; + + // The actual request, which in itself is a tagged union of all possible + // operation request types. 
+ OpRequest request = 4; +} + +// Describes a sequence of operation requests which define an XLA +// computation. +message SessionComputation { + string name = 1; + + // The ComputationHandle used to refer to this computation in the XLA + // service. + ComputationHandle computation_handle = 2; + + // Map from ComputationDataHandle value to operation request. The highest + // ComputationDataHandle value corresponds to the root of the computation. + map requests = 3; + + // The list of Trace requests in this SessionComputation. + repeated TraceRequest trace_requests = 4; + + // The list of Send requests in this SessionComputation. + repeated SendRequest send_requests = 5; +} + +// Describes a group of SessionComputations with an "entry point" computation +// that may refer to the other non-entry (AKA embedded) computations. +// +// This message is used to serialize a computation that has been built via the +// XLA service API, along with its dependencies, for purposes such as +// analysis/replay/file-storage. +message SessionModule { + // The entry computation, which was requested for serialization. This may have + // referred to embedded computations, which are reflected below. + SessionComputation entry = 1; + + // Embedded computations that are transitively referred to by the entry + // computation. + repeated SessionComputation embedded_computations = 2; + + // The arguments passed to the computation. + repeated Literal arguments = 3; + + // The result of the computation. + Literal result = 4; + + // The name of the platform used to run the computation. + string execution_platform = 5; +} diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc new file mode 100644 index 0000000000..11559ad757 --- /dev/null +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -0,0 +1,1380 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/shape_inference.h" + +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/window_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/math/math_util.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace xla { + +namespace { + +// Returns true if no element is present in slice more than once. 
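Stepping back to the session.proto messages above: to make the snapshot layout concrete, here is a hedged C++ sketch of populating a trivial SessionModule. The names and handle values are purely illustrative, and some_operation_request is a placeholder.

xla::SessionModule module;
xla::SessionComputation* entry = module.mutable_entry();
entry->set_name("add_two_numbers");                  // illustrative name
entry->mutable_computation_handle()->set_handle(1);  // illustrative handle
// Each enqueued operation would be recorded in the requests map, keyed by the
// handle of the value it produces; the highest key is the computation's root:
//   (*entry->mutable_requests())[2] = some_operation_request;
module.set_execution_platform("Host");  // e.g. the name of the CPU platform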
+bool AllUnique(tensorflow::gtl::ArraySlice slice) { + return std::set(slice.begin(), slice.end()).size() == slice.size(); +} + +tensorflow::Status ExpectNotTupleOrOpaque(const Shape& shape, + tensorflow::StringPiece op_type) { + if (ShapeUtil::IsTuple(shape)) { + return InvalidArgument("Expected non-tuple argument for %s. Got: %s", + op_type.ToString().c_str(), + ShapeUtil::HumanString(shape).c_str()); + } else if (ShapeUtil::IsOpaque(shape)) { + return InvalidArgument("Expected non-opaque argument for %s. Got: %s", + op_type.ToString().c_str(), + ShapeUtil::HumanString(shape).c_str()); + } else { + return tensorflow::Status::OK(); + } +} + +tensorflow::Status VerifyReducerShape(const ProgramShape& reducer_shape, + const Shape& init_value_shape, + const PrimitiveType& input_element_type) { + if (reducer_shape.parameters_size() != 2) { + return InvalidArgument( + "Reduction function must take 2 parameters, but " + "takes %d parameter(s).", + reducer_shape.parameters_size()); + } + + const Shape& accumulator_shape = reducer_shape.result(); + if (ShapeUtil::Rank(accumulator_shape) != 0) { + return Unimplemented( + "Reduction function currently must have rank-0 result."); + } + + // Check that the accumulator can be passed in as the first argument. + // Note: comparing here and below with Compatible since we don't care about + // layout in scalars - see b/26668201 for a longer-term vision. + if (!ShapeUtil::Compatible(accumulator_shape, reducer_shape.parameters(0))) { + return InvalidArgument( + "Reduction function's first parameter shape differs from the " + "result shape: %s vs %s", + ShapeUtil::HumanString(reducer_shape.parameters(0)).c_str(), + ShapeUtil::HumanString(accumulator_shape).c_str()); + } + + // Check that init_value's shape is suitable for reducer_shape. + if (!ShapeUtil::Compatible(accumulator_shape, init_value_shape)) { + return InvalidArgument( + "Reduction function's accumulator shape differs from the " + "init_value shape: %s vs %s", + ShapeUtil::HumanString(accumulator_shape).c_str(), + ShapeUtil::HumanString(init_value_shape).c_str()); + } + + // Check that the inputs can be passed in as the second argument. + const Shape& input_element_shape = + ShapeUtil::MakeShape(input_element_type, {}); + if (!ShapeUtil::Compatible(input_element_shape, + reducer_shape.parameters(1))) { + return InvalidArgument( + "Reduction function's second parameter shape differs from the " + "input type element type: %s vs %s", + ShapeUtil::HumanString(reducer_shape.parameters(1)).c_str(), + ShapeUtil::HumanString(input_element_shape).c_str()); + } + + // Currently the accumulator and inputs must be the same type, + // though that restriction could be relaxed. + if (!ShapeUtil::Compatible(accumulator_shape, reducer_shape.parameters(1))) { + return InvalidArgument( + "Reduction function's second parameter shape currently must " + "match the result shape. 
Got %s vs %s", + ShapeUtil::HumanString(reducer_shape.parameters(1)).c_str(), + ShapeUtil::HumanString(accumulator_shape).c_str()); + } + + return tensorflow::Status::OK(); +} + +StatusOr InferWindowOutputShape(const Shape& base_shape, + const Window& window, + PrimitiveType element_type, + bool allow_negative_padding) { + if (window.dimensions_size() != ShapeUtil::Rank(base_shape)) { + return InvalidArgument( + "Window has dimension %d but base shape has dimension %lld.", + window.dimensions_size(), ShapeUtil::Rank(base_shape)); + } + + std::vector output_dimensions(window.dimensions_size()); + for (int64 i = 0; i < window.dimensions_size(); ++i) { + const auto& dim = window.dimensions(i); + if (dim.size() <= 0) { + return InvalidArgument("Window has a non-positive dimension. Window: %s", + window.DebugString().c_str()); + } + if (dim.stride() <= 0) { + return InvalidArgument("Window has a non-positive stride. Window: %s", + window.DebugString().c_str()); + } + if (!allow_negative_padding && dim.padding_low() < 0) { + return InvalidArgument("Window has a negative low padding. Window: %s", + window.DebugString().c_str()); + } + if (!allow_negative_padding && dim.padding_high() < 0) { + return InvalidArgument("Window has a negative high padding. Window: %s", + window.DebugString().c_str()); + } + if (dim.base_dilation() < 1) { + return InvalidArgument( + "Window has a non-positive base area dilation factor. Window: %s", + window.DebugString().c_str()); + } + if (dim.window_dilation() < 1) { + return InvalidArgument( + "Window has a non-positive window dilation factor. Window: %s", + window.DebugString().c_str()); + } + + const int64 dilated_base = window_util::DilatedBound( + ShapeUtil::GetDimension(base_shape, i), dim.base_dilation()); + const int64 padded_dilated_base = + dim.padding_low() + dilated_base + dim.padding_high(); + const int64 dilated_window = + window_util::DilatedBound(dim.size(), dim.window_dilation()); + + output_dimensions[i] = window_util::StridedBound( + padded_dilated_base, dilated_window, dim.stride()); + } + + return ShapeUtil::MakeShape(element_type, output_dimensions); +} + +} // namespace + +/* static */ StatusOr ShapeInference::InferUnaryOpShape( + UnaryOperation operation, const Shape& arg) { + TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(arg, "operand of unary operation")); + + TF_DCHECK_OK(ShapeUtil::ValidateShape(arg)); + switch (operation) { + case UNOP_FLOOR: + case UNOP_CEIL: + case UNOP_EXP: + case UNOP_LOG: + case UNOP_TANH: + if (!ShapeUtil::ElementIsFloating(arg)) { + return InvalidArgument( + "expected element type in shape to be floating for exp/log/tanh " + "operation; got %s", + PrimitiveType_Name(arg.element_type()).c_str()); + } + return arg; + case UNOP_ABS: + case UNOP_SIGN: + case UNOP_NEGATE: + case UNOP_SORT: + return arg; + + case UNOP_LOGICAL_NOT: + if (arg.element_type() != PRED) { + return InvalidArgument( + "expected pred element type in argument to logical-not operation; " + "got %s", + PrimitiveType_Name(arg.element_type()).c_str()); + } + return arg; + default: + return InvalidArgument("unknown operation %s", + UnaryOperation_Name(operation).c_str()); + } +} + +/* static */ StatusOr ShapeInference::InferConcatOpShape( + tensorflow::gtl::ArraySlice arg_shapes, + const int64 dimension) { + if (arg_shapes.size() == 0) { + return InvalidArgument("Concatenate expects at least one argument"); + } + if (dimension < 0 || dimension >= ShapeUtil::Rank(*arg_shapes[0])) { + return InvalidArgument("dimension to concatenate along out of bounds: 
%lld", + dimension); + } + const Shape* arg_shape = nullptr; + for (const Shape* shape : arg_shapes) { + TF_RETURN_IF_ERROR( + ExpectNotTupleOrOpaque(*shape, "operand of concatenation")); + if (!arg_shape) { + arg_shape = shape; + continue; + } + if (ShapeUtil::Rank(*arg_shape) != ShapeUtil::Rank(*shape)) { + return InvalidArgument( + "cannot concatenate arrays with different ranks: %lld vs %lld", + ShapeUtil::Rank(*arg_shape), ShapeUtil::Rank(*shape)); + } + if (arg_shape->element_type() != shape->element_type()) { + return InvalidArgument( + "cannot concatenate arrays with different element types: %s vs %s", + PrimitiveType_Name(arg_shape->element_type()).c_str(), + PrimitiveType_Name(shape->element_type()).c_str()); + } + for (int64 dimension_number = 0; + dimension_number < ShapeUtil::Rank(*arg_shape); ++dimension_number) { + if (arg_shape->dimensions(dimension_number) != + shape->dimensions(dimension_number)) { + if (dimension_number == dimension) { + continue; // It's okay to differ in the dimension we're + // concatenating. + } + return InvalidArgument( + "cannot concatenate arrays that differ in dimensions other than " + "the one being concatenated (the other array dimensions must be " + "the same): %s vs %s", + ShapeUtil::HumanString(*arg_shape).c_str(), + ShapeUtil::HumanString(*shape).c_str()); + } + } + } + + std::vector new_dimensions(arg_shape->dimensions().begin(), + arg_shape->dimensions().end()); + for (size_t i = 1; i < arg_shapes.size(); ++i) { + new_dimensions[dimension] += arg_shapes[i]->dimensions(dimension); + } + return ShapeUtil::MakeShape(arg_shape->element_type(), new_dimensions); +} + +/* static */ StatusOr ShapeInference::InferConvertShape( + const Shape& operand_shape, PrimitiveType new_element_type) { + if (ShapeUtil::IsTuple(operand_shape) || new_element_type == TUPLE) { + // Note: we may want to support tuple conversions via this operation in the + // future, by recursing into the tuple elements to check all sub-conversions + // are valid. For now we just reject them, though. 
+ return InvalidArgument( + "cannot convert from or to tuple type; requested conversion: %s => %s", + ShapeUtil::HumanString(operand_shape).c_str(), + PrimitiveType_Name(new_element_type).c_str()); + } + + return ShapeUtil::ChangeElementType(operand_shape, new_element_type); +} + +/* static */ StatusOr ShapeInference::InferPadShape( + const Shape& operand_shape, const Shape& padding_value_shape, + const PaddingConfig& padding_config) { + if (ShapeUtil::IsTuple(operand_shape)) { + return InvalidArgument( + "pad operation does not support tuple-shape operands"); + } + if (!ShapeUtil::IsScalar(padding_value_shape)) { + return InvalidArgument( + "pad operation does not support non-scalar padding values"); + } + if (ShapeUtil::Rank(operand_shape) != padding_config.dimensions_size()) { + return InvalidArgument( + "the rank of the operand and the padding configuration do not match."); + } + std::vector dimensions(ShapeUtil::Rank(operand_shape)); + for (int64 i = 0; i < operand_shape.dimensions_size(); ++i) { + dimensions[i] = operand_shape.dimensions(i) + + padding_config.dimensions(i).edge_padding_low() + + padding_config.dimensions(i).edge_padding_high() + + std::max(operand_shape.dimensions(i) - 1, 0LL) * + padding_config.dimensions(i).interior_padding(); + } + return ShapeUtil::MakeShape(operand_shape.element_type(), dimensions); +} + +/* static */ StatusOr ShapeInference::InferDotOpShape(const Shape& lhs, + const Shape& rhs) { + TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of dot")); + TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of dot")); + + auto fail = [lhs, rhs](const string& addendum) -> Status { + string message = tensorflow::strings::Printf( + "cannot infer shape for dot operation: %s %s", + ShapeUtil::HumanString(lhs).c_str(), + ShapeUtil::HumanString(rhs).c_str()); + if (!addendum.empty()) { + message += ": " + addendum; + } + return InvalidArgument("%s", message.c_str()); + }; + + // Check if both element types are the same. + if (lhs.element_type() != rhs.element_type()) { + return fail("element types mismatch"); + } + + if (ShapeUtil::Rank(lhs) < 1 || ShapeUtil::Rank(lhs) > 2 || + ShapeUtil::Rank(rhs) < 1 || ShapeUtil::Rank(rhs) > 2) { + return fail("dot only supports rank 1 or 2"); + } + + // Determine the index of the contracted dimensions for input tensors. + // dimensions -1 of lhs and dimension 0 of rhs are contracted. + int64 lhs_contracted_dimension = ShapeUtil::GetDimensionNumber(lhs, -1); + int64 rhs_contracted_dimension = 0; + + // Check if the contracted dimension sizes are the same. + if ((lhs_contracted_dimension < ShapeUtil::Rank(lhs) && + rhs_contracted_dimension < ShapeUtil::Rank(rhs)) && + lhs.dimensions(lhs_contracted_dimension) != + rhs.dimensions(rhs_contracted_dimension)) { + return fail("contracted dimensions mismatch"); + } + + // The ranks of lhs and rhs are decremented by 1 respectively due to the + // contraction, and added for the rank of the result. When an input tensor is + // a scalar, its contribution to the rank of the result is 0. + // Generate the result dimensions in order, rhs dimensions followed by lhs + // dimensions except the contracted dimensions. 
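Aside on InferPadShape above: a minimal standalone sketch of the per-dimension size formula, in plain C++ rather than the XLA Shape types; the example values are illustrative only.

#include <algorithm>
#include <cstdint>
#include <iostream>

// Per-dimension output size for Pad, mirroring the formula in InferPadShape:
// out = in + edge_low + edge_high + max(in - 1, 0) * interior.
int64_t PaddedSize(int64_t in, int64_t low, int64_t high, int64_t interior) {
  return in + low + high + std::max<int64_t>(in - 1, 0) * interior;
}

int main() {
  // A [4] operand padded with edge_padding_low=1, edge_padding_high=2 and
  // interior_padding=1 (one padding element between neighbors) gives
  // 4 + 1 + 2 + 3 * 1 = 10 elements.
  std::cout << PaddedSize(4, 1, 2, 1) << "\n";  // 10
}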
+ std::vector dimensions; + for (int64 i = 0; i < ShapeUtil::Rank(lhs); i++) { + if (i != lhs_contracted_dimension) { + dimensions.push_back(lhs.dimensions(i)); + } + } + for (int64 i = 0; i < ShapeUtil::Rank(rhs); i++) { + if (i != rhs_contracted_dimension) { + dimensions.push_back(rhs.dimensions(i)); + } + } + Shape result = ShapeUtil::MakeShape(lhs.element_type(), dimensions); + + TF_DCHECK_OK(ShapeUtil::ValidateShape(result)); + VLOG(2) << "inferred dot shape: " << ShapeUtil::HumanString(result); + return result; +} + +/* static */ StatusOr +ShapeInference::InferDegenerateDimensionBroadcastShape( + BinaryOperation operation, const Shape& lhs, const Shape& rhs) { + TF_RET_CHECK(ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs)); + + // The shapes have to be compatible. That is, if some dimension d has a + // different size in the two shapes, one of them has to be 1 (a "degenerate" + // dimension). In that case, the output shape has the non-1 dimension size + // from the lhs/rhs pair in every index. + std::vector output_dimensions(ShapeUtil::Rank(lhs)); + for (int64 i = 0; i < ShapeUtil::Rank(lhs); ++i) { + if (lhs.dimensions(i) == rhs.dimensions(i)) { + output_dimensions[i] = lhs.dimensions(i); + } else if (lhs.dimensions(i) == 1) { + output_dimensions[i] = rhs.dimensions(i); + } else if (rhs.dimensions(i) == 1) { + output_dimensions[i] = lhs.dimensions(i); + } else { + return InvalidArgument("binary op %s with incompatible shapes: %s and %s", + BinaryOperation_Name(operation).c_str(), + ShapeUtil::HumanString(lhs).c_str(), + ShapeUtil::HumanString(rhs).c_str()); + } + } + return ShapeUtil::MakeShape(lhs.element_type(), output_dimensions); +} + +/* static */ StatusOr ShapeInference::InferInDimBroadcastShape( + BinaryOperation operation, const Shape& smaller_shape, + const Shape& larger_shape, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + if (broadcast_dimensions.empty() && !ShapeUtil::IsScalar(smaller_shape)) { + // Reject "magic" inference for binops on different shapes, requiring + // the user to provide an explicit broadcast dimension in this case. + // See b/25177275 for more details. + return InvalidArgument("automatic shape inference not supported: %s and %s", + ShapeUtil::HumanString(smaller_shape).c_str(), + ShapeUtil::HumanString(larger_shape).c_str()); + } else if (broadcast_dimensions.size() != ShapeUtil::Rank(smaller_shape)) { + return InvalidArgument( + "size of broadcast_dimensions has to match lower-rank operand's " + "rank; " + " lower-rank operand's rank is %lld, size of broadcast_dimensions is " + "%zu", + ShapeUtil::Rank(smaller_shape), broadcast_dimensions.size()); + } + + // broadcast_dimensions is a sequence of dimensions; its length is equal to + // the rank of the lower-rank operand. The lower-rank operand's dimensions + // have to be compatible with the higher-rank operand's dimensions at indices + // specified by broadcast_dimensions. Here compatible means the dimension + // sizes are equal or in one of the shapes the dimension size is + // one. Examples: + // + // smaller_shape larger_shape broadcast_dimensions output_shape + // [] [2, 3] {} [2, 3] + // [3] [4, 3] {1} [4, 3] + // [2, 3] [2, 3, 4] {0, 1} [2, 3, 4] + // [2, 1] [2, 3, 4] {0, 2} [2, 3, 1] + // [2, 3] [2, 1, 4] {0, 1} [2, 3, 4] + // + // The column output_shape may not be the final shape of the XLA + // operation. 
After the "InDim" broadcasting implemented in this function + // expands the rank, degenerate-dimension broadcasting (implemented in + // InferDegenerateDimensionBroadcastShape) broadcasts dimensions of size one + // up to match the dimension size of the other operand. For example, consider + // the row in the table above with a smaller_shape of [2, 1]. The shape + // returned by this function is [2, 3, 1] (output_shape) however, the result + // shape of the XLA operation is [2, 3, 4] after degenerate-dimension + // broadcasting. + // + // Invalid broadcasts: + // + // smaller_shape=[3], larger_shape=[4, 3], broadcast_dimensions={0} + // Reason: Dimension zero** of larger_shape (size 4) is not compatible with + // dimension zero of smaller_shape(size 3). **Zero here comes from the value + // in broadcast_dimensions. + // + // smaller_shape=[2, 1], larger_shape=[2, 3, 4], broadcast_dimensions={1, 2} + // Reason: Dimension one of larger_shape (size 3) is not compatible with + // dimension zero of smaller_shape(size 2) + + // The output shape is initially the larger_shape. Sizes of dimensions + // specified in broadcast_dimensions are then changed to match the + // corresponding dimension size in smaller_shape. + Shape output_shape(larger_shape); + + for (int i = 0; i < smaller_shape.dimensions_size(); ++i) { + int64 dimension_to_match = broadcast_dimensions.at(i); + if (dimension_to_match < 0) { + return InvalidArgument( + "broadcast dimension number (%lld) cannot be negative", + dimension_to_match); + } + if (dimension_to_match >= larger_shape.dimensions_size()) { + return InvalidArgument( + "broadcast dimension number (%lld) too large; higher-rank " + "operand has rank %d", + dimension_to_match, larger_shape.dimensions_size()); + } + int64 small_dimension_size = smaller_shape.dimensions(i); + int64 large_dimension_size = larger_shape.dimensions(dimension_to_match); + // Dimension sizes must be compatible: match or be degenerate (degenerate + // case is handled by degenerate dimension broadcasting which occurs after + // InDim broadcasting). + if (small_dimension_size != large_dimension_size && + small_dimension_size != 1 && large_dimension_size != 1) { + return InvalidArgument( + "broadcast dimension %d mismatch: %lld != %lld; %s and %s", i, + small_dimension_size, large_dimension_size, + ShapeUtil::HumanString(smaller_shape).c_str(), + ShapeUtil::HumanString(larger_shape).c_str()); + } + // Make sure the broadcast dimensions are listed in a strictly increasing + // order. 
+ if (i > 0 && broadcast_dimensions.at(i - 1) >= dimension_to_match) { + return InvalidArgument( + "broadcast dimensions order is wrong: %lld comes after %lld", + dimension_to_match, broadcast_dimensions.at(i - 1)); + } + + output_shape.set_dimensions(dimension_to_match, small_dimension_size); + } + + return output_shape; +} + +/* static */ StatusOr ShapeInference::InferElementwiseBinaryOpShape( + BinaryOperation operation, const Shape& lhs, const Shape& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + TF_RETURN_IF_ERROR( + ExpectNotTupleOrOpaque(lhs, "lhs of elementwise binary operation")); + TF_RETURN_IF_ERROR( + ExpectNotTupleOrOpaque(rhs, "rhs of elementwise binary operation")); + + if (!ShapeUtil::SameElementType(lhs, rhs)) { + return InvalidArgument("binary op with different element types: %s and %s", + ShapeUtil::HumanString(lhs).c_str(), + ShapeUtil::HumanString(rhs).c_str()); + } + + if (ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs) && + !broadcast_dimensions.empty()) { + return InvalidArgument( + "broadcast dimensions field should not be set on binary " + "operations with operands of the same rank"); + } + + if (ShapeUtil::Compatible(lhs, rhs)) { + // If the shapes are the same other than layout, the output shape is the + // same (elementwise op). + return lhs; + } + + if (ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs)) { + return InferDegenerateDimensionBroadcastShape(operation, lhs, rhs); + } else { + // Ranks do not match, so perform InDim broadcasting using + // broadcast_dimensions. Scalar broadcasting is a special case of this). + const Shape& larger_shape = + ShapeUtil::Rank(lhs) > ShapeUtil::Rank(rhs) ? lhs : rhs; + const Shape& smaller_shape = + ShapeUtil::Rank(lhs) > ShapeUtil::Rank(rhs) ? rhs : lhs; + + // After InDim broadcasting, perform degenerate dimensions broadcasting. 
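A companion sketch for the degenerate-dimension stage mentioned above, again in plain C++ with error handling omitted; it assumes the two shapes have already passed the compatibility checks.

#include <cstdint>
#include <iostream>
#include <vector>

// Standalone sketch of degenerate-dimension broadcasting for two shapes of
// equal rank: mismatching dimensions are allowed only if one of them is 1,
// and the result takes the non-1 size.
std::vector<int64_t> DegenerateBroadcastDims(const std::vector<int64_t>& lhs,
                                             const std::vector<int64_t>& rhs) {
  std::vector<int64_t> out(lhs.size());
  for (size_t i = 0; i < lhs.size(); ++i) {
    out[i] = lhs[i] == 1 ? rhs[i] : lhs[i];
  }
  return out;
}

int main() {
  // Continuing the [2, 1] vs [2, 3, 4] example: InDim broadcasting yields
  // [2, 3, 1], and degenerate broadcasting against [2, 3, 4] yields [2, 3, 4].
  for (int64_t d : DegenerateBroadcastDims({2, 3, 1}, {2, 3, 4})) {
    std::cout << d << " ";
  }
  std::cout << "\n";
}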
+ TF_ASSIGN_OR_RETURN( + Shape indim_broadcast_shape, + InferInDimBroadcastShape(operation, smaller_shape, larger_shape, + broadcast_dimensions)); + + return InferDegenerateDimensionBroadcastShape( + operation, indim_broadcast_shape, larger_shape); + } +} + +/* static */ StatusOr ShapeInference::InferBinaryOpShape( + BinaryOperation operation, const Shape& lhs, const Shape& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + VLOG(2) << tensorflow::strings::Printf( + "inferring shape for <%s>(%s, %s) with broadcast_dimensions={%s}", + BinaryOperation_Name(operation).c_str(), + ShapeUtil::HumanString(lhs).c_str(), ShapeUtil::HumanString(rhs).c_str(), + tensorflow::str_util::Join(broadcast_dimensions, ", ").c_str()); + TF_DCHECK_OK(ShapeUtil::ValidateShape(lhs)); + TF_DCHECK_OK(ShapeUtil::ValidateShape(rhs)); + + TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of binary operation")); + TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of binary operation")); + switch (operation) { + case BINOP_DOT: + return InferDotOpShape(lhs, rhs); + case BINOP_MAX: + case BINOP_MIN: + case BINOP_SUB: + case BINOP_ADD: + case BINOP_POW: + case BINOP_DIV: + case BINOP_REM: + case BINOP_MUL: + return InferElementwiseBinaryOpShape(operation, lhs, rhs, + broadcast_dimensions); + + case BINOP_LOGICAL_AND: + case BINOP_LOGICAL_OR: + if (lhs.element_type() != PRED) { + return InvalidArgument( + "expected pred element type in argument to logical and/or " + "operation; got %s", + PrimitiveType_Name(lhs.element_type()).c_str()); + } + return InferElementwiseBinaryOpShape(operation, lhs, rhs, + broadcast_dimensions); + + case BINOP_EQ: + case BINOP_GE: + case BINOP_GT: + case BINOP_LE: + case BINOP_LT: + case BINOP_NE: { + TF_ASSIGN_OR_RETURN(const Shape& shape, + InferElementwiseBinaryOpShape(operation, lhs, rhs, + broadcast_dimensions)); + return ShapeUtil::ChangeElementType(shape, PRED); + } + case BINOP_INDEX: + if (ShapeUtil::Rank(lhs) > 0 && ShapeUtil::Rank(rhs) == 0) { + tensorflow::gtl::ArraySlice dimensions = + AsInt64Slice(lhs.dimensions()); + dimensions.pop_front(); + return ShapeUtil::MakeShape(lhs.element_type(), dimensions); + } + return Unimplemented("cannot infer shape for operation: %s <%s> %s", + ShapeUtil::HumanString(lhs).c_str(), + BinaryOperation_Name(operation).c_str(), + ShapeUtil::HumanString(rhs).c_str()); + default: + return Unimplemented( + "not yet implemented; infer binary op shape: %s; lhs: %s; rhs: %s", + BinaryOperation_Name(operation).c_str(), + lhs.ShortDebugString().c_str(), rhs.ShortDebugString().c_str()); + } +} + +/* static */ StatusOr ShapeInference::InferTernaryOpShape( + TernaryOperation operation, const Shape& lhs, const Shape& rhs, + const Shape& ehs) { + TF_DCHECK_OK(ShapeUtil::ValidateShape(lhs)); + TF_DCHECK_OK(ShapeUtil::ValidateShape(rhs)); + TF_DCHECK_OK(ShapeUtil::ValidateShape(ehs)); + switch (operation) { + case TRIOP_CLAMP: + TF_RETURN_IF_ERROR( + ExpectNotTupleOrOpaque(lhs, "lhs of ternary operation")); + TF_RETURN_IF_ERROR( + ExpectNotTupleOrOpaque(rhs, "rhs of ternary operation")); + TF_RETURN_IF_ERROR( + ExpectNotTupleOrOpaque(ehs, "ehs of ternary operation")); + if (((ShapeUtil::Compatible(lhs, rhs) || ShapeUtil::Rank(lhs) == 0) && + (ShapeUtil::Compatible(rhs, ehs) || ShapeUtil::Rank(ehs) == 0))) { + return rhs; + } + if (ShapeUtil::Rank(rhs) == 0) { + if (ShapeUtil::Compatible(lhs, ehs)) { + return lhs; + } + return ShapeUtil::Rank(ehs) == 0 ? 
lhs : ehs; + } + return Unimplemented("not yet implemented: %s, %s %s", + lhs.ShortDebugString().c_str(), + ehs.ShortDebugString().c_str(), + rhs.ShortDebugString().c_str()); + case TRIOP_SELECT: + return InferSelectShape(lhs, rhs, ehs); + case TRIOP_UPDATE: + TF_RETURN_IF_ERROR( + ExpectNotTupleOrOpaque(lhs, "lhs of ternary operation")); + TF_RETURN_IF_ERROR( + ExpectNotTupleOrOpaque(rhs, "rhs of ternary operation")); + TF_RETURN_IF_ERROR( + ExpectNotTupleOrOpaque(ehs, "ehs of ternary operation")); + return lhs; + default: + return InvalidArgument("unknown operation %s", + TernaryOperation_Name(operation).c_str()); + } +} + +/* static */ StatusOr ShapeInference::InferVariadicOpShape( + VariadicOperation operation, std::vector operand_shapes) { + for (const Shape* shape : operand_shapes) { + TF_DCHECK_OK(ShapeUtil::ValidateShape(*shape)); + } + switch (operation) { + case VAROP_TUPLE: { + Shape result = ShapeUtil::MakeTupleShape({}); + for (const Shape* shape : operand_shapes) { + ShapeUtil::AppendShapeToTuple(*shape, &result); + } + return result; + } + default: + return InvalidArgument("unknown operation %s", + VariadicOperation_Name(operation).c_str()); + } +} + +/* static */ StatusOr ShapeInference::InferMapShape( + tensorflow::gtl::ArraySlice arg_shapes, + const ProgramShape& to_apply) { + if (arg_shapes.size() == 0) { + return InvalidArgument("Map expects at least one argument"); + } + + // All arguments must have the same shape. + const Shape* arg_shape = arg_shapes[0]; + for (size_t i = 1; i < arg_shapes.size(); ++i) { + TF_RETURN_IF_ERROR( + ExpectNotTupleOrOpaque(*arg_shapes[i], "operand of map")); + + if (ShapeUtil::Compatible(*arg_shapes[i], *arg_shape)) { + continue; + } + if (!ShapeUtil::IsTuple(*arg_shapes[i]) && + !ShapeUtil::IsTuple(*arg_shape) && + ShapeUtil::SameElementType(*arg_shapes[i], *arg_shape)) { + if (ShapeUtil::IsScalar(*arg_shapes[i])) { + continue; + } + if (ShapeUtil::IsScalar(*arg_shape)) { + arg_shape = arg_shapes[i]; + continue; + } + } + + std::vector pieces; + for (const Shape* shape : arg_shapes) { + pieces.push_back(ShapeUtil::HumanString(*shape)); + } + return InvalidArgument( + "Map operation requires all operands to have the same shape; got: " + "%s", + tensorflow::str_util::Join(pieces, ", ").c_str()); + } + + // The applied function's arity equals the number of arguments. + if (arg_shapes.size() != to_apply.parameters_size()) { + return InvalidArgument( + "Map applied function arity must match number of arguments; got: " + "arity: %d, arguments: %zu", + to_apply.parameters_size(), arg_shapes.size()); + } + + // The parameters should all be scalars, and the output too. 
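Aside on InferMapShape: the result rule reduces to "operand dimensions plus the mapped computation's result element type". A toy sketch, using strings and plain vectors as stand-ins for XLA shapes:

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// The mapped computation is applied elementwise, so the result keeps the
// operands' dimensions and takes the element type of the computation's
// (scalar) result.
struct ToyShape {
  std::string element_type;
  std::vector<int64_t> dims;
};

ToyShape MapResult(const ToyShape& arg, const std::string& apply_result_type) {
  return ToyShape{apply_result_type, arg.dims};
}

int main() {
  // Mapping an (f32, f32) -> pred computation over f32[10, 20] operands
  // produces pred[10, 20].
  ToyShape out = MapResult({"f32", {10, 20}}, "pred");
  std::cout << out.element_type << "[" << out.dims[0] << "," << out.dims[1] << "]\n";
}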
+ const Shape& output_shape = to_apply.result(); + if (!ShapeUtil::IsScalar(output_shape)) { + return InvalidArgument( + "mapped computation's result has to be a scalar; " + "got: %s", + ShapeUtil::HumanString(output_shape).c_str()); + } + + for (int i = 0; i < to_apply.parameters_size(); ++i) { + const Shape& parameter_shape = to_apply.parameters(i); + + if (!ShapeUtil::IsScalar(parameter_shape)) { + return InvalidArgument( + "mapped computation's parameter has to be a scalar; " + "got parameter %d shape: %s", + i, ShapeUtil::HumanString(parameter_shape).c_str()); + } + + if (parameter_shape.element_type() != arg_shape->element_type()) { + return InvalidArgument( + "mapped computation's parameter type has to match argument element " + "type; got parameter %d shape: %s, argument shape: %s", + i, ShapeUtil::HumanString(parameter_shape).c_str(), + ShapeUtil::HumanString(*arg_shape).c_str()); + } + } + + return ShapeUtil::MakeShape(output_shape.element_type(), + AsInt64Slice(arg_shape->dimensions())); +} + +/* static */ StatusOr ShapeInference::InferConvolveShape( + const Shape& lhs, const Shape& rhs, const Window& window, + const ConvolutionDimensionNumbers& dnums) { + TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of convolution")); + TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of convolution")); + + if (!ShapeUtil::SameElementType(lhs, rhs)) { + return InvalidArgument( + "Convolution with different element types: %s and %s", + ShapeUtil::HumanString(lhs).c_str(), + ShapeUtil::HumanString(rhs).c_str()); + } + if (dnums.spatial_dimensions_size() != + dnums.kernel_spatial_dimensions_size()) { + return InvalidArgument( + "Both arguments to convolution must have same number of dimensions.\n" + "Window: %s", + window.DebugString().c_str()); + } + int num_spatial_dims = dnums.spatial_dimensions_size(); + if (num_spatial_dims < 1) { + return InvalidArgument( + "Convolution requires at least one spatial dimension.\n" + "Window: %s", + window.DebugString().c_str()); + } + + if (window.dimensions_size() != num_spatial_dims) { + return InvalidArgument( + "Window must have same number of dimensions as dimension numbers.\n" + "Window: %s\nDimension numbers: %s", + window.DebugString().c_str(), dnums.DebugString().c_str()); + } + + int num_dims = num_spatial_dims + 2; + if (ShapeUtil::Rank(lhs) != num_dims) { + return InvalidArgument( + "The LHS argument to a convolution should have rank %d.\n" + "lhs: %s", + num_dims, ShapeUtil::HumanString(lhs).c_str()); + } + if (ShapeUtil::Rank(rhs) != num_dims) { + return InvalidArgument( + "The RHS argument to a convolution should have rank %d.\n" + "lhs: %s", + num_dims, ShapeUtil::HumanString(lhs).c_str()); + } + TF_DCHECK_OK(ShapeUtil::ValidateShape(lhs)); + TF_DCHECK_OK(ShapeUtil::ValidateShape(rhs)); + + // Verifies that the input and window dimensions are a permutation of + // the dimension numbers. 
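The permutation check described here can be illustrated with a small standalone sketch; the NHWC-style numbers in the example are illustrative only.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

// Standalone sketch of the permutation check used below: the batch, feature
// and spatial dimension numbers together must name every dimension of the
// operand exactly once, i.e. they must sort to 0, 1, ..., rank - 1.
bool IsPermutationOfAllDims(std::vector<int64_t> dnums) {
  std::vector<int64_t> expected(dnums.size());
  std::iota(expected.begin(), expected.end(), 0);
  std::sort(dnums.begin(), dnums.end());
  return dnums == expected;
}

int main() {
  // e.g. NHWC-style input numbers: batch=0, feature=3, spatial={1, 2}.
  std::cout << IsPermutationOfAllDims({0, 3, 1, 2}) << "\n";  // 1 (valid)
  std::cout << IsPermutationOfAllDims({0, 3, 1, 1}) << "\n";  // 0 (dim 1 repeated)
}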
+ std::vector input_dnums(num_dims); + input_dnums[0] = dnums.batch_dimension(); + input_dnums[1] = dnums.feature_dimension(); + std::copy(dnums.spatial_dimensions().begin(), + dnums.spatial_dimensions().end(), input_dnums.begin() + 2); + std::sort(input_dnums.begin(), input_dnums.end()); + + std::vector window_dnums(num_dims); + window_dnums[0] = dnums.kernel_input_feature_dimension(); + window_dnums[1] = dnums.kernel_output_feature_dimension(); + std::copy(dnums.kernel_spatial_dimensions().begin(), + dnums.kernel_spatial_dimensions().end(), window_dnums.begin() + 2); + std::sort(window_dnums.begin(), window_dnums.end()); + + std::vector expected_dnums(num_dims); + std::iota(expected_dnums.begin(), expected_dnums.end(), 0); + + const auto in_range = [num_dims](int64 i) { return 0 <= i && i < num_dims; }; + if (!std::all_of(input_dnums.begin(), input_dnums.end(), in_range) || + !std::all_of(window_dnums.begin(), window_dnums.end(), in_range)) { + return InvalidArgument( + "A dimension number is out of range in convolution: %s", + dnums.DebugString().c_str()); + } + + if (input_dnums != expected_dnums) { + return InvalidArgument( + "Input dimensions of convolution must contain each dimension exactly " + "once: %s", + dnums.DebugString().c_str()); + } + if (window_dnums != expected_dnums) { + return InvalidArgument( + "Window dimensions of convolution must contain each dimension exactly " + "once: %s", + dnums.DebugString().c_str()); + } + + std::vector input_spatial_dims(num_spatial_dims); + for (int i = 0; i < num_spatial_dims; ++i) { + input_spatial_dims[i] = lhs.dimensions(dnums.spatial_dimensions(i)); + } + const int64 input_features = lhs.dimensions(dnums.feature_dimension()); + const int64 input_batch = lhs.dimensions(dnums.batch_dimension()); + + std::vector kernel_spatial_dims(num_spatial_dims); + for (int i = 0; i < num_spatial_dims; ++i) { + kernel_spatial_dims[i] = rhs.dimensions(dnums.kernel_spatial_dimensions(i)); + } + const int64 kernel_input_features = + rhs.dimensions(dnums.kernel_input_feature_dimension()); + const int64 kernel_output_features = + rhs.dimensions(dnums.kernel_output_feature_dimension()); + + if (input_features != kernel_input_features) { + return InvalidArgument( + "Expected LHS feature dimension (value %lld) to match RHS " + "input feature dimension (value %lld); got (%s, %s)\n" + "Dimension numbers: {%s}", + input_features, kernel_input_features, + ShapeUtil::HumanString(lhs).c_str(), + ShapeUtil::HumanString(rhs).c_str(), dnums.DebugString().c_str()); + } + std::vector window_dims(num_spatial_dims); + for (int i = 0; i < num_spatial_dims; ++i) { + window_dims[i] = window.dimensions(i).size(); + } + if (kernel_spatial_dims != window_dims) { + return InvalidArgument( + "Window dimensions do not match RHS shape:\n\t" + "RHS shape: %s\n\t" + "Window: {%s}\n\t" + "Dimension numbers: {%s}", + ShapeUtil::HumanString(rhs).c_str(), window.ShortDebugString().c_str(), + dnums.ShortDebugString().c_str()); + } + + Shape base_shape = + ShapeUtil::MakeShape(lhs.element_type(), input_spatial_dims); + TF_ASSIGN_OR_RETURN( + Shape window_output_shape, + InferWindowOutputShape(base_shape, window, lhs.element_type(), + /*allow_negative_padding=*/true)); + + std::vector dimensions(num_dims); + dimensions[dnums.batch_dimension()] = input_batch; + dimensions[dnums.feature_dimension()] = kernel_output_features; + for (int i = 0; i < num_spatial_dims; ++i) { + dimensions[dnums.spatial_dimensions(i)] = window_output_shape.dimensions(i); + } + + return 
ShapeUtil::MakeShape(lhs.element_type(), dimensions); +} + +/* static */ StatusOr ShapeInference::InferCrossReplicaSumShape( + const Shape& operand) { + TF_RETURN_IF_ERROR( + ExpectNotTupleOrOpaque(operand, "operand of cross replica sum")); + return operand; +} + +/* static */ StatusOr ShapeInference::InferReduceShape( + const Shape& arg, const Shape& init_value, + tensorflow::gtl::ArraySlice dimensions_to_reduce, + const ProgramShape& to_apply) { + // Check that the dimension to reduce are in-bounds for the given shape. + for (int64 dimension : dimensions_to_reduce) { + if (dimension >= ShapeUtil::Rank(arg) || dimension < 0) { + return InvalidArgument( + "attempting to reduce out-of-bounds dimension %lld in shape %s", + dimension, ShapeUtil::HumanString(arg).c_str()); + } + } + TF_RETURN_IF_ERROR( + VerifyReducerShape(to_apply, init_value, arg.element_type())); + + std::set dimensions_to_reduce_set(dimensions_to_reduce.begin(), + dimensions_to_reduce.end()); + std::vector new_dimensions; + for (int i = 0; i < ShapeUtil::Rank(arg); ++i) { + if (dimensions_to_reduce_set.find(i) == dimensions_to_reduce_set.end()) { + new_dimensions.push_back(arg.dimensions(i)); + } + } + + return ShapeUtil::MakeShape(to_apply.result().element_type(), new_dimensions); +} + +/* static */ StatusOr ShapeInference::InferReduceWindowShape( + const Shape& operand_shape, const Shape& init_value_shape, + const Window& window, const ProgramShape& to_apply_shape) { + TF_RETURN_IF_ERROR( + ExpectNotTupleOrOpaque(operand_shape, "operand of reduce-window")); + TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, init_value_shape, + operand_shape.element_type())); + return InferWindowOutputShape(operand_shape, window, + init_value_shape.element_type(), + /*allow_negative_padding=*/false); +} + +/* static */ StatusOr ShapeInference::InferSelectAndScatterShape( + const Shape& operand_shape, const ProgramShape& select_shape, + const Window& window, const Shape& source_shape, + const Shape& init_value_shape, const ProgramShape& scatter_shape) { + TF_RETURN_IF_ERROR( + ExpectNotTupleOrOpaque(operand_shape, "operand of select-and-scatter")); + + // Check if the select function has a proper shape of (T,T) -> PRED. + if (select_shape.parameters_size() != 2) { + return InvalidArgument( + "select function must take 2 parameters, but " + "takes %d parameter(s).", + select_shape.parameters_size()); + } + const Shape& select_result_shape = select_shape.result(); + if (!ShapeUtil::Compatible(select_result_shape, + ShapeUtil::MakeShape(PRED, {}))) { + return Unimplemented("select function must have rank-0 PRED result."); + } + const Shape& operand_element_shape = + ShapeUtil::MakeShape(operand_shape.element_type(), {}); + if (!ShapeUtil::Compatible(operand_element_shape, + select_shape.parameters(0))) { + return InvalidArgument( + "select function's first parameter shape currently must " + "match the operand element shape. Got %s vs %s", + ShapeUtil::HumanString(select_shape.parameters(0)).c_str(), + ShapeUtil::HumanString(operand_element_shape).c_str()); + } + if (!ShapeUtil::Compatible(operand_element_shape, + select_shape.parameters(1))) { + return InvalidArgument( + "select function's second parameter shape currently must " + "match the operand element shape. Got %s vs %s", + ShapeUtil::HumanString(select_shape.parameters(1)).c_str(), + ShapeUtil::HumanString(operand_element_shape).c_str()); + } + + // Check if the scatter function has a proper shape as a reduction. 
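Aside on InferReduceShape above: the output dimensions are simply the operand dimensions with the reduced ones dropped. A minimal standalone sketch with illustrative values:

#include <cstdint>
#include <iostream>
#include <set>
#include <vector>

// Standalone sketch of the Reduce result dimensions: keep every operand
// dimension that is not listed in dimensions_to_reduce.
std::vector<int64_t> ReducedDims(const std::vector<int64_t>& arg_dims,
                                 const std::set<int64_t>& dims_to_reduce) {
  std::vector<int64_t> out;
  for (int64_t i = 0; i < static_cast<int64_t>(arg_dims.size()); ++i) {
    if (dims_to_reduce.count(i) == 0) {
      out.push_back(arg_dims[i]);
    }
  }
  return out;
}

int main() {
  // Reducing an f32[2, 3, 4] operand over dimensions {0, 2} leaves [3]; the
  // element type of the result comes from the reducer's result shape.
  for (int64_t d : ReducedDims({2, 3, 4}, {0, 2})) std::cout << d << " ";
  std::cout << "\n";
}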
+ TF_RETURN_IF_ERROR(VerifyReducerShape(scatter_shape, init_value_shape, + source_shape.element_type())); + + // Check if the result shape of window operation matches the source shape. + TF_ASSIGN_OR_RETURN(const Shape& window_result_shape, + InferWindowOutputShape(operand_shape, window, + operand_shape.element_type(), + /*allow_negative_padding=*/false)); + if (!ShapeUtil::Compatible(source_shape, window_result_shape)) { + return InvalidArgument( + "source shape does not match the shape of window-reduced operand: " + "source(%s), window-reduced operand(%s)", + ShapeUtil::HumanString(source_shape).c_str(), + ShapeUtil::HumanString(window_result_shape).c_str()); + } + return operand_shape; +} + +/* static */ StatusOr ShapeInference::InferSliceShape( + const Shape& arg, tensorflow::gtl::ArraySlice starts, + tensorflow::gtl::ArraySlice limits) { + TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(arg, "operand of slice")); + VLOG(2) << tensorflow::strings::Printf( + "slicing shape %s starts={%s} limits={%s}", + ShapeUtil::HumanString(arg).c_str(), + tensorflow::str_util::Join(starts, ", ").c_str(), + tensorflow::str_util::Join(limits, ", ").c_str()); + + if (starts.size() != limits.size()) { + return InvalidArgument("slice start and limit sizes differ: %zu vs %zu", + starts.size(), limits.size()); + } + + if (starts.size() != ShapeUtil::Rank(arg)) { + return InvalidArgument( + "slice index count does not match argument rank: %zu vs %lld", + starts.size(), ShapeUtil::Rank(arg)); + } + + std::vector sizes; + for (int64 dimension = 0; dimension < starts.size(); ++dimension) { + int64 start_index = starts[dimension]; + int64 limit_index = limits[dimension]; + if (start_index < 0) { + return InvalidArgument("negative start index to slice: %lld", + start_index); + } + if (limit_index < 0) { + return InvalidArgument("negative limit index to slice: %lld", + limit_index); + } + if (limit_index > arg.dimensions(dimension)) { + return InvalidArgument( + "limit index (%lld) must be less than or equal to dimension " + "size (%lld)", + limit_index, arg.dimensions(dimension)); + } + if (start_index > limit_index) { + return InvalidArgument( + "limit index (%lld) must be greater or equal to " + "start index (%lld) in slice", + limit_index, start_index); + } + VLOG(2) << tensorflow::strings::Printf("starts[%lld] = %lld", dimension, + start_index); + VLOG(2) << tensorflow::strings::Printf("limits[%lld] = %lld", dimension, + limit_index); + + sizes.push_back(limits[dimension] - starts[dimension]); + } + + return ShapeUtil::MakeShape(arg.element_type(), sizes); +} + +/* static */ StatusOr ShapeInference::InferDynamicSliceShape( + const Shape& operand_shape, const Shape& start_indices_shape, + tensorflow::gtl::ArraySlice slice_sizes) { + TF_RETURN_IF_ERROR( + ExpectNotTupleOrOpaque(operand_shape, "operand of dynamic slice")); + TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(start_indices_shape, + "start indices of dynamic slice")); + + VLOG(2) << tensorflow::strings::Printf( + "slicing shape %s at dynamic start_indices %s with slice_sizes={%s}", + ShapeUtil::HumanString(operand_shape).c_str(), + ShapeUtil::HumanString(start_indices_shape).c_str(), + tensorflow::str_util::Join(slice_sizes, ", ").c_str()); + + if (ShapeUtil::Rank(start_indices_shape) != 1) { + return InvalidArgument( + "dynamic slice start indices of rank %lld must be rank1.", + ShapeUtil::Rank(start_indices_shape)); + } + + if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) { + return InvalidArgument( + "dynamic slice start indices must be of integral 
type."); + } + + const int64 start_num_dims = start_indices_shape.dimensions(0); + if (ShapeUtil::Rank(operand_shape) != start_num_dims) { + return InvalidArgument( + "dynamic slice start number of dimensions %lld must match rank %lld of " + "slice input", + start_num_dims, ShapeUtil::Rank(operand_shape)); + } + + if (slice_sizes.size() != ShapeUtil::Rank(operand_shape)) { + return InvalidArgument( + "dynamic slice index count does not match argument rank: %zu vs %lld", + slice_sizes.size(), ShapeUtil::Rank(operand_shape)); + } + + for (int64 dim = 0; dim < slice_sizes.size(); ++dim) { + const int64 input_dim_size = operand_shape.dimensions(dim); + const int64 slice_dim_size = slice_sizes[dim]; + if (slice_dim_size <= 0) { + return InvalidArgument("negative size index to dynamic slice: %lld", + slice_dim_size); + } + if (slice_dim_size > input_dim_size) { + return InvalidArgument( + "slice dim size %lld greater than dynamic slice dimension: %lld", + slice_dim_size, input_dim_size); + } + VLOG(2) << tensorflow::strings::Printf("slice_sizes[%lld] = %lld", dim, + slice_dim_size); + } + + return ShapeUtil::MakeShape(operand_shape.element_type(), slice_sizes); +} + +/* static */ StatusOr ShapeInference::InferDynamicUpdateSliceShape( + const Shape& operand_shape, const Shape& update_shape, + const Shape& start_indices_shape) { + TF_RETURN_IF_ERROR( + ExpectNotTupleOrOpaque(operand_shape, "operand of dynamic update slice")); + TF_RETURN_IF_ERROR( + ExpectNotTupleOrOpaque(update_shape, "update of dynamic update slice")); + TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( + start_indices_shape, "start indices of dynamic update slice")); + + VLOG(2) << tensorflow::strings::Printf( + "updating slice of shape %s at dynamic start_indices %s with update " + "shape %s", + ShapeUtil::HumanString(operand_shape).c_str(), + ShapeUtil::HumanString(start_indices_shape).c_str(), + ShapeUtil::HumanString(update_shape).c_str()); + + if (ShapeUtil::Rank(start_indices_shape) != 1) { + return InvalidArgument( + "dynamic update slice start indices of rank %lld must be rank1.", + ShapeUtil::Rank(start_indices_shape)); + } + + if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) { + return InvalidArgument( + "dynamic update slice start indices must be of integral type."); + } + + const int64 start_num_dims = start_indices_shape.dimensions(0); + if (ShapeUtil::Rank(operand_shape) != start_num_dims) { + return InvalidArgument( + "dynamic update slice start number of dimensions %lld must match " + "rank %lld of slice input", + start_num_dims, ShapeUtil::Rank(operand_shape)); + } + + if (ShapeUtil::Rank(update_shape) != ShapeUtil::Rank(operand_shape)) { + return InvalidArgument( + "dynamic update slice update rank does not match argument rank: " + "%lld vs %lld", + ShapeUtil::Rank(update_shape), ShapeUtil::Rank(operand_shape)); + } + + if (operand_shape.element_type() != update_shape.element_type()) { + return InvalidArgument( + "dynamic update slice update element type does not match argument. 
" + "operand.element_type: %s vs update.element_type: %s", + PrimitiveType_Name(operand_shape.element_type()).c_str(), + PrimitiveType_Name(update_shape.element_type()).c_str()); + } + + for (int64 dim = 0; dim < ShapeUtil::Rank(operand_shape); ++dim) { + const int64 input_dim_size = operand_shape.dimensions(dim); + const int64 update_dim_size = update_shape.dimensions(dim); + if (update_dim_size <= 0) { + return InvalidArgument( + "size index %lld to dynamic update slice must be > 0", + update_dim_size); + } + if (update_dim_size > input_dim_size) { + return InvalidArgument( + "update dim size %lld greater than dynamic slice dimension: %lld", + update_dim_size, input_dim_size); + } + VLOG(2) << tensorflow::strings::Printf("update_sizes[%lld] = %lld", dim, + update_dim_size); + } + + return operand_shape; +} + +/*static */ StatusOr ShapeInference::InferReverseShape( + const Shape& operand_shape, tensorflow::gtl::ArraySlice dimensions) { + TF_RETURN_IF_ERROR( + ExpectNotTupleOrOpaque(operand_shape, "operand of reverse")); + if (!AllUnique(dimensions)) { + return InvalidArgument("a dimension number is duplicated in reverse"); + } + for (int64 dimension : dimensions) { + if (dimension >= ShapeUtil::Rank(operand_shape) || dimension < 0) { + return InvalidArgument( + "one of the reverse dimensions (%lld) is out-of-bounds in shape %s", + dimension, ShapeUtil::HumanString(operand_shape).c_str()); + } + } + return operand_shape; +} + +/* static */ StatusOr ShapeInference::InferGetTupleElementShape( + const Shape& arg, int64 index) { + if (!ShapeUtil::IsTuple(arg)) { + return InvalidArgument( + "cannot infer shape: attempting to index into non-tuple: %s", + ShapeUtil::HumanString(arg).c_str()); + } + + if (index >= arg.tuple_shapes_size()) { + return InvalidArgument( + "cannot infer shape: attempt to index out of tuple bounds: %lld " + ">= %d in shape %s", + index, arg.tuple_shapes_size(), ShapeUtil::HumanString(arg).c_str()); + } + + return arg.tuple_shapes(index); +} + +/* static */ StatusOr ShapeInference::InferWhileShape( + const ProgramShape& condition, const ProgramShape& body, + const Shape& init) { + // Check the number of parameters for given computations. + if (condition.parameters_size() != 1) { + return InvalidArgument("condition must take 1 arguments; got %d", + condition.parameters_size()); + } + if (body.parameters_size() != 1) { + return InvalidArgument("body must take 1 arguments; got %d", + body.parameters_size()); + } + + string shape_string = tensorflow::strings::Printf( + "condition: %s; body: %s; init: %s", condition.ShortDebugString().c_str(), + body.ShortDebugString().c_str(), init.ShortDebugString().c_str()); + + // Check the shapes of computation parameters and return types. 
+ if (!ShapeUtil::ShapeIs(condition.result(), PRED, {})) { + return InvalidArgument("condition must return a boolean; got %s", + shape_string.c_str()); + } + if (!ShapeUtil::Compatible(body.result(), condition.parameters(0)) || + !ShapeUtil::Compatible(body.result(), body.parameters(0)) || + !ShapeUtil::Compatible(body.result(), init)) { + return InvalidArgument( + "the parameter of condition and body, the result of the body, and init " + "must all have the same shape; got %s", + shape_string.c_str()); + } + + return init; +} + +/* static */ StatusOr ShapeInference::InferBroadcastShape( + const Shape& operand, tensorflow::gtl::ArraySlice broadcast_sizes) { + TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "operand of broadcast")); + for (int64 size : broadcast_sizes) { + if (size < 0) { + return InvalidArgument("Broadcast with negative dimension size %lld.", + size); + } + } + + std::vector dimensions(operand.dimensions_size() + + broadcast_sizes.size()); + std::copy(broadcast_sizes.begin(), broadcast_sizes.end(), dimensions.begin()); + std::copy(operand.dimensions().begin(), operand.dimensions().end(), + dimensions.begin() + broadcast_sizes.size()); + return ShapeUtil::MakeShape(operand.element_type(), dimensions); +} + +/* static */ StatusOr ShapeInference::InferReshapeShape( + const Shape& operand, tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice new_sizes) { + TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "reshape")); + + Shape inferred_shape = + ShapeUtil::MakeShape(operand.element_type(), new_sizes); + + if (ShapeUtil::ElementsIn(operand) != ShapeUtil::ElementsIn(inferred_shape)) { + return InvalidArgument( + "reshape operation has mismatched element counts: from=%lld to=%lld", + ShapeUtil::ElementsIn(operand), ShapeUtil::ElementsIn(inferred_shape)); + } + + std::vector indices(ShapeUtil::Rank(operand)); + std::iota(indices.begin(), indices.end(), 0); + if (dimensions.size() != ShapeUtil::Rank(operand) || + !std::is_permutation(dimensions.begin(), dimensions.end(), + indices.begin())) { + return InvalidArgument( + "Reshape dimensions not a permutation of the operand dimensions."); + } + + return inferred_shape; +} + +/* static */ StatusOr ShapeInference::InferTransposeShape( + const Shape& operand, tensorflow::gtl::ArraySlice dimensions) { + TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "transpose")); + + std::vector indices(ShapeUtil::Rank(operand)); + std::iota(indices.begin(), indices.end(), 0); + if (dimensions.size() != ShapeUtil::Rank(operand) || + !std::is_permutation(dimensions.begin(), dimensions.end(), + indices.begin())) { + return InvalidArgument( + "Transpose dimensions not a permutation of the operand dimensions."); + } + + // Permute(dimensions,input) computes output[dimensions[i]]=input[i]. However, + // we need output[i]=input[dimensions[i]] which is + // Permute(Inverse(dimensions),input). 
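The comment above can be made concrete with a small standalone sketch that computes output[i] = input[dimensions[i]] directly, instead of going through Permute and InversePermutation; the example values are illustrative.

#include <cstdint>
#include <iostream>
#include <vector>

// Standalone sketch of the transpose result dimensions: result dimension i
// takes its size from operand dimension dimensions[i].
std::vector<int64_t> TransposedDims(const std::vector<int64_t>& operand_dims,
                                    const std::vector<int64_t>& dimensions) {
  std::vector<int64_t> out(operand_dims.size());
  for (size_t i = 0; i < dimensions.size(); ++i) {
    out[i] = operand_dims[dimensions[i]];
  }
  return out;
}

int main() {
  // Transposing f32[2, 3, 4] with dimensions {2, 0, 1} yields f32[4, 2, 3].
  for (int64_t d : TransposedDims({2, 3, 4}, {2, 0, 1})) std::cout << d << " ";
  std::cout << "\n";
}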
+ return ShapeUtil::MakeShape(operand.element_type(), + Permute(InversePermutation(dimensions), + AsInt64Slice(operand.dimensions()))); +} + +/* static */ StatusOr ShapeInference::InferSelectShape( + const Shape& pred, const Shape& on_true, const Shape& on_false) { + if (!ShapeUtil::Compatible(on_true, on_false)) { + return InvalidArgument( + "operands to select must be the same shape; got %s and %s", + ShapeUtil::HumanString(on_true).c_str(), + ShapeUtil::HumanString(on_false).c_str()); + } + if (pred.element_type() != PRED) { + return InvalidArgument( + "select's pred operand must have PRED element type; got %s", + ShapeUtil::HumanString(pred).c_str()); + } + if (ShapeUtil::SameDimensions(pred, on_true) || ShapeUtil::Rank(pred) == 0) { + // By this stage we know that pred's element type is PRED. Therefore, this + // check restricts pred to be a PRED scalar, or a PRED array with the same + // dimensions as on_true and on_false. + return on_true; + } else { + return Unimplemented( + "select operation with non-scalar predicate with dimensionality " + " different from the other operands: %s", + ShapeUtil::HumanString(pred).c_str()); + } +} + +/* static */ StatusOr ShapeInference::InferCallShape( + tensorflow::gtl::ArraySlice arg_shapes, + const ProgramShape& to_apply) { + // The applied function's arity equals the number of arguments. + if (arg_shapes.size() != to_apply.parameters_size()) { + return InvalidArgument( + "Call applied function arity must match number of arguments; got: " + "arity: %d, arguments: %zu", + to_apply.parameters_size(), arg_shapes.size()); + } + + // All arguments must be compatible with the program shape. + for (int i = 0; i < arg_shapes.size(); ++i) { + const Shape& arg_shape = *arg_shapes[i]; + const Shape& param_shape = to_apply.parameters(i); + if (!ShapeUtil::Compatible(arg_shape, param_shape)) { + return InvalidArgument( + "Call parameter must match argument; got parameter %d shape: %s, " + "argument shape: %s", + i, ShapeUtil::HumanString(param_shape).c_str(), + ShapeUtil::HumanString(arg_shape).c_str()); + } + } + + return to_apply.result(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h new file mode 100644 index 0000000000..ced2f4d001 --- /dev/null +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -0,0 +1,219 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Shape inference is used by the XLA service as the user builds up +// computation requests. 
+ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SHAPE_INFERENCE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_SHAPE_INFERENCE_H_ + +#include + +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// For a given operation and input shapes, infers what the resulting shape is +// for the operation. With this functionality, the user does not need to +// specify the expected result type for computations that are built up via the +// API -- the shape that results from an operation is inferred. +class ShapeInference { + public: + // Infers the shape produced by applying the given unary operation to the + // given input shape. + static StatusOr InferUnaryOpShape(UnaryOperation operation, + const Shape& arg); + + // Infers the shape produced by applying the given binary operation to the + // given input shapes. + static StatusOr InferBinaryOpShape( + BinaryOperation operation, const Shape& lhs, const Shape& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + + // Infers the shape produced by applying the given ternary operation to the + // given input shapes. + static StatusOr InferTernaryOpShape(TernaryOperation operation, + const Shape& lhs, const Shape& rhs, + const Shape& ehs); + + // Infers the shape produced by applying the given variadic operation to the + // given input operand shapes. + static StatusOr InferVariadicOpShape( + VariadicOperation operation, std::vector operand_shapes); + + // Infers the shape produced by applying the given mapping computation shape + // to the given operand shapes. + static StatusOr InferMapShape( + tensorflow::gtl::ArraySlice arg_shapes, + const ProgramShape& to_apply); + + // Infers the shape produced by applying the given convolutional + // filter (rhs) to lhs in the way specified by the fields on window. + static StatusOr InferConvolveShape( + const Shape& lhs, const Shape& rhs, const Window& window, + const ConvolutionDimensionNumbers& dimension_numbers); + + // Infers the shape produced a cross replica sum with the given operand shape. + static StatusOr InferCrossReplicaSumShape(const Shape& operand); + + // Infers the shape produced by applying the given reduction computation + // shape to the given input operand shape. + // + // If pass_index is true, the reduce function is invoked with the element + // index as the leading parameter, and the program shape should match + // accordingly (or an error will result). + static StatusOr InferReduceShape( + const Shape& arg, const Shape& init_value, + tensorflow::gtl::ArraySlice dimensions_to_reduce, + const ProgramShape& to_apply); + + // Infers the shape produced by applying the given computation to the operand + // shape with the given window and stride dimensions. + static StatusOr InferReduceWindowShape( + const Shape& operand_shape, const Shape& init_value, const Window& window, + const ProgramShape& to_apply_shape); + + // Infers the shape produced by scattering the given source shape to the + // selected indices of each window on the operand shape. 
+ static StatusOr InferSelectAndScatterShape( + const Shape& operand_shape, const ProgramShape& select_shape, + const Window& window, const Shape& source_shape, + const Shape& init_value_shape, const ProgramShape& scatter_shape); + + // Infers the shape produced by a reverse operation that reverses the order + // of the elements in the given dimensions. + static StatusOr InferReverseShape( + const Shape& operand_shape, + tensorflow::gtl::ArraySlice dimensions); + + // Infers the shape produced by a slice operation spanning from the starts to + // the limits in the original shape's dimensions. + // + // e.g. slice f32[32x32] 0:16 0:16 -> f32[16x16] + static StatusOr InferSliceShape( + const Shape& arg, tensorflow::gtl::ArraySlice starts, + tensorflow::gtl::ArraySlice limits); + + // Infers the shape produced by a dynamic slice operation of size specified + // in 'slice_sizes', with dynamic start indices shape 'start_indices_shape'. + static StatusOr InferDynamicSliceShape( + const Shape& operand_shape, const Shape& start_indices_shape, + tensorflow::gtl::ArraySlice slice_sizes); + + // Infers the shape produced by a dynamic update slice operation based + // on the shape of operand and update. + static StatusOr InferDynamicUpdateSliceShape( + const Shape& operand_shape, const Shape& update_shape, + const Shape& start_indices_shape); + + // Infers the shape produced by doing a compile-time-constant indexing into + // the given input shape. This is essential for operations on tuples, because + // it is impossible to infer the type that comes out of the tuple indexing if + // it is not a compile time constant. + static StatusOr InferGetTupleElementShape(const Shape& arg, + int64 index); + + // Infers the shape produced from a while node. condition and body are the + // shapes of computations for the condition and the body of a while node, and + // init is the shape of data initially passed in to the body as an argument. + // The shapes must match; condition: T -> PRED, body: T -> T, init: T + static StatusOr InferWhileShape(const ProgramShape& condition, + const ProgramShape& body, + const Shape& init); + + // Infers the shape produced by a broadcast operation. + static StatusOr InferBroadcastShape( + const Shape& operand, tensorflow::gtl::ArraySlice broadcast_sizes); + + // Infers the shape produced by a reshape operation from the element type of + // its operand and the new dimension sizes specified. + static StatusOr InferReshapeShape( + const Shape& operand, tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice new_sizes); + + // Infers the shape produced by a transpose operation from the element type of + // its operand and its dimensions field. + static StatusOr InferTransposeShape( + const Shape& operand, tensorflow::gtl::ArraySlice dimensions); + + // Helper that infers the shape produced by performing a concatenate operation + // with the given operand shapes. + static StatusOr InferConcatOpShape( + tensorflow::gtl::ArraySlice arg_shapes, int64 dimension); + + // Helper that validates the given operand shape can be converted to the + // target output_shape via a convert instruction -- the requirement is that + // the shape is identical except for the element type. + static StatusOr InferConvertShape(const Shape& operand_shape, + PrimitiveType new_element_type); + + // Helper that infers the shape produced by a pad operation based on the + // padding configuration. 
+  static StatusOr<Shape> InferPadShape(const Shape& operand_shape,
+                                       const Shape& padding_value_shape,
+                                       const PaddingConfig& padding_config);
+
+  // Helper that validates the given arg_shapes are compatible with the shape of
+  // the to_apply parameters, and returns the to_apply result shape.
+  static StatusOr<Shape> InferCallShape(
+      tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
+      const ProgramShape& to_apply);
+
+ private:
+  // Helper that infers the shape produced by performing a dot operation with
+  // the given LHS and RHS shapes.
+  static StatusOr<Shape> InferDotOpShape(const Shape& lhs, const Shape& rhs);
+
+  // Helper that infers the shape produced by performing an element-wise binary
+  // operation with the given LHS and RHS shapes.
+  // Note: By "element-wise" we mean operations that look at a single element in
+  // the LHS and a single element in the RHS to produce a single output element,
+  // even in the presence of broadcasting of one of the operands over the other.
+  static StatusOr<Shape> InferElementwiseBinaryOpShape(
+      BinaryOperation operation, const Shape& lhs, const Shape& rhs,
+      tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+
+  // Helper for inferring the shape of Select ops.
+  static StatusOr<Shape> InferSelectShape(const Shape& pred,
+                                          const Shape& on_true,
+                                          const Shape& on_false);
+
+  // Helper for inferring shapes of binary operations which use degenerate
+  // dimension broadcasting (a dimension of size 1 in one operand is broadcast
+  // up to match the size of the dimension in the other operand).
+  static StatusOr<Shape> InferDegenerateDimensionBroadcastShape(
+      BinaryOperation operation, const Shape& lhs, const Shape& rhs);
+
+  // Helper for inferring shapes of binary operations using "InDim"
+  // broadcasting. This is the broadcasting used in the *InDim binary operations
+  // (for example ComputationBuilder::AddInDim). smaller_shape must be a
+  // lower-rank shape than larger_shape. Returns the shape that the
+  // smaller_shape is broadcast to.
+  static StatusOr<Shape> InferInDimBroadcastShape(
+      BinaryOperation operation, const Shape& smaller_shape,
+      const Shape& larger_shape,
+      tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+
+  TF_DISALLOW_COPY_AND_ASSIGN(ShapeInference);
+};
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_SHAPE_INFERENCE_H_
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
new file mode 100644
index 0000000000..10fd4e53c5
--- /dev/null
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -0,0 +1,1133 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/shape_inference.h"
+
+#include <string>
+
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/types.h"
+
+namespace xla {
+namespace {
+
+class ShapeInferenceTest : public ::testing::Test {
+ protected:
+  // Some handy scalar shapes.
+  const Shape s32_ = ShapeUtil::MakeShape(S32, {});
+  const Shape f32_ = ShapeUtil::MakeShape(F32, {});
+  const Shape pred_ = ShapeUtil::MakeShape(PRED, {});
+
+  // Some handy vector and matrix shapes of F32 type.
+  // Suffix: vector_length_, matrix_rows_cols_
+  const Shape vector_32_ = ShapeUtil::MakeShape(F32, {32});
+  const Shape vector_64_ = ShapeUtil::MakeShape(F32, {64});
+  const Shape matrix_32_48_ = ShapeUtil::MakeShape(F32, {32, 48});
+  const Shape matrix_32_64_ = ShapeUtil::MakeShape(F32, {32, 64});
+  const Shape matrix_64_48_ = ShapeUtil::MakeShape(F32, {64, 48});
+
+  // Some handy S32 arrays.
+  const Shape s32matrix_64_64_ = ShapeUtil::MakeShape(S32, {64, 64});
+};
+
+// Subclass for testing InferReduceShape.
+class ReduceShapeInferenceTest : public ShapeInferenceTest {
+ protected:
+  // Helper that runs reduce shape inference with the input 'arg' and given
+  // dimensions to reduce, and checks the inferred shape is as expected. The
+  // element type here is hard-coded to F32.
+  void ExpectInferredReduceShape(
+      const Shape& expected_inferred_shape, const Shape& arg,
+      tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) {
+    ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);
+    auto inferred_status = ShapeInference::InferReduceShape(
+        arg, f32_, dimensions_to_reduce, to_apply);
+    EXPECT_IS_OK(inferred_status.status());
+    EXPECT_TRUE(ShapeUtil::Equal(expected_inferred_shape,
+                                 inferred_status.ValueOrDie()));
+  }
+};
+
+// Subclass for testing InferSelectAndScatterShape.
+class SelectAndScatterShapeInferenceTest : public ShapeInferenceTest { + protected: + SelectAndScatterShapeInferenceTest() { + operand_shape_ = ShapeUtil::MakeShape(F32, {8, 16}); + source_shape_ = ShapeUtil::MakeShape(F32, {4, 8}); + WindowDimension dim; + dim.set_size(2); + dim.set_stride(2); + dim.set_padding_low(0); + dim.set_padding_high(0); + dim.set_window_dilation(1); + dim.set_base_dilation(1); + *window_.add_dimensions() = dim; + *window_.add_dimensions() = dim; + init_value_shape_ = ShapeUtil::MakeShape(F32, {}); + select_program_shape_ = ShapeUtil::MakeProgramShape( + {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, pred_); + scatter_program_shape_ = ShapeUtil::MakeProgramShape( + {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, f32_); + } + + Shape operand_shape_; + Shape source_shape_; + Window window_; + Shape init_value_shape_; + ProgramShape select_program_shape_; + ProgramShape scatter_program_shape_; +}; + +TEST_F(ShapeInferenceTest, UnaryNegateMatrix) { + Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64}); + auto inferred_status = ShapeInference::InferUnaryOpShape( + UnaryOperation::UNOP_NEGATE, matrix_shape); + ASSERT_IS_OK(inferred_status.status()); + ASSERT_TRUE(ShapeUtil::Equal(matrix_shape, inferred_status.ValueOrDie())); +} + +TEST_F(ShapeInferenceTest, SelectScalarPredBetweenTuples) { + Shape tuple = ShapeUtil::MakeTupleShape({s32_, f32_}); + auto inferred_status = ShapeInference::InferTernaryOpShape( + TernaryOperation::TRIOP_SELECT, pred_, tuple, tuple); + ASSERT_IS_OK(inferred_status.status()); + ASSERT_TRUE(ShapeUtil::Equal(tuple, inferred_status.ValueOrDie())); +} + +TEST_F(ShapeInferenceTest, SelectScalarPredBetweenArrays) { + auto inferred_status = ShapeInference::InferTernaryOpShape( + TernaryOperation::TRIOP_SELECT, pred_, matrix_64_48_, matrix_64_48_); + ASSERT_IS_OK(inferred_status.status()); + ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); +} + +TEST_F(ShapeInferenceTest, SelectArrayPredBetweenArrays) { + auto predarray = ShapeUtil::MakeShape(PRED, {64, 48}); + auto inferred_status = ShapeInference::InferTernaryOpShape( + TernaryOperation::TRIOP_SELECT, predarray, matrix_64_48_, matrix_64_48_); + ASSERT_IS_OK(inferred_status.status()); + ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); +} + +TEST_F(ShapeInferenceTest, SelectBadShapes) { + auto inferred_status_error1 = ShapeInference::InferTernaryOpShape( + TernaryOperation::TRIOP_SELECT, pred_, matrix_64_48_, matrix_32_64_); + ASSERT_FALSE(inferred_status_error1.ok()); + ASSERT_MATCH( + inferred_status_error1.status().error_message(), + testing::ContainsRegex("operands to select must be the same shape")); + + auto inferred_status_error2 = ShapeInference::InferTernaryOpShape( + TernaryOperation::TRIOP_SELECT, s32_, matrix_64_48_, matrix_64_48_); + ASSERT_FALSE(inferred_status_error2.ok()); + ASSERT_MATCH(inferred_status_error2.status().error_message(), + testing::ContainsRegex("pred operand must have PRED")); + + auto inferred_status_error3 = ShapeInference::InferTernaryOpShape( + TernaryOperation::TRIOP_SELECT, ShapeUtil::MakeShape(PRED, {64}), + matrix_64_48_, matrix_64_48_); + ASSERT_FALSE(inferred_status_error3.ok()); + ASSERT_MATCH( + inferred_status_error3.status().error_message(), + testing::ContainsRegex("with non-scalar predicate with dimensionality")); + + // Tuples have a TUPLE element type and cannot be the pred of a select. 
+  auto inferred_status_error4 = ShapeInference::InferTernaryOpShape(
+      TernaryOperation::TRIOP_SELECT, ShapeUtil::MakeTupleShape({pred_, pred_}),
+      ShapeUtil::MakeTupleShape({f32_, f32_}),
+      ShapeUtil::MakeTupleShape({f32_, f32_}));
+  ASSERT_FALSE(inferred_status_error4.ok());
+  ASSERT_MATCH(
+      inferred_status_error4.status().error_message(),
+      testing::ContainsRegex("pred operand must have PRED element type"));
+}
+
+TEST_F(ShapeInferenceTest, VariadicOpTuplify) {
+  StatusOr<Shape> result = ShapeInference::InferVariadicOpShape(
+      VariadicOperation::VAROP_TUPLE, {&s32_, &f32_});
+  ASSERT_IS_OK(result.status());
+  ASSERT_TRUE(ShapeUtil::Equal(result.ValueOrDie(),
+                               ShapeUtil::MakeTupleShape({s32_, f32_})));
+}
+
+TEST_F(ShapeInferenceTest, ReduceWindowInHalf) {
+  Shape matrix_shape = ShapeUtil::MakeShape(F32, {8, 8});
+  Window window;
+  WindowDimension dim;
+  dim.set_size(2);
+  dim.set_stride(2);
+  dim.set_padding_low(0);
+  dim.set_padding_high(0);
+  dim.set_window_dilation(1);
+  dim.set_base_dilation(1);
+  *window.add_dimensions() = dim;
+  *window.add_dimensions() = dim;
+  Shape window_shape = ShapeUtil::MakeShape(F32, {2, 2});
+  Shape init_value_shape = ShapeUtil::MakeShape(F32, {});
+  Shape float_scalar = ShapeUtil::MakeShape(F32, {});
+  ProgramShape to_apply = ShapeUtil::MakeProgramShape(
+      {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, f32_);
+  auto inferred_status = ShapeInference::InferReduceWindowShape(
+      matrix_shape, init_value_shape, window, to_apply);
+
+  ASSERT_IS_OK(inferred_status.status());
+  Shape inferred = inferred_status.ValueOrDie();
+  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {4, 4}), inferred));
+}
+
+TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterProperShapes) {
+  auto inferred_status_ok = ShapeInference::InferSelectAndScatterShape(
+      operand_shape_, select_program_shape_, window_, source_shape_,
+      init_value_shape_, scatter_program_shape_);
+  ASSERT_IS_OK(inferred_status_ok.status());
+  Shape inferred = inferred_status_ok.ValueOrDie();
+  ASSERT_TRUE(ShapeUtil::Equal(operand_shape_, inferred));
+}
+
+TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSourceShape) {
+  Shape source_shape_fail = ShapeUtil::MakeShape(F32, {4, 6});
+  auto inferred_status_fail = ShapeInference::InferSelectAndScatterShape(
+      operand_shape_, select_program_shape_, window_, source_shape_fail,
+      init_value_shape_, scatter_program_shape_);
+  ASSERT_FALSE(inferred_status_fail.ok());
+  ASSERT_MATCH(inferred_status_fail.status().error_message(),
+               testing::ContainsRegex("source shape does not match"));
+}
+
+TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape1) {
+  ProgramShape select_program_shape_fail =
+      ShapeUtil::MakeProgramShape({ShapeUtil::MakeShape(F32, {})}, pred_);
+  auto inferred_status_fail = ShapeInference::InferSelectAndScatterShape(
+      operand_shape_, select_program_shape_fail, window_, source_shape_,
+      init_value_shape_, scatter_program_shape_);
+  ASSERT_FALSE(inferred_status_fail.ok());
+  ASSERT_MATCH(
+      inferred_status_fail.status().error_message(),
+      testing::ContainsRegex("select function must take 2 parameters"));
+}
+
+TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape2) {
+  ProgramShape select_program_shape_fail = ShapeUtil::MakeProgramShape(
+      {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, f32_);
+  auto inferred_status_fail = ShapeInference::InferSelectAndScatterShape(
+      operand_shape_, select_program_shape_fail, window_, source_shape_,
+      init_value_shape_,
scatter_program_shape_); + ASSERT_FALSE(inferred_status_fail.ok()); + ASSERT_MATCH(inferred_status_fail.status().error_message(), + testing::ContainsRegex("select function must have rank-0 PRED")); +} + +TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape3) { + ProgramShape select_program_shape_fail = ShapeUtil::MakeProgramShape( + {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {})}, pred_); + auto inferred_status_fail = ShapeInference::InferSelectAndScatterShape( + operand_shape_, select_program_shape_fail, window_, source_shape_, + init_value_shape_, scatter_program_shape_); + ASSERT_FALSE(inferred_status_fail.ok()); + ASSERT_MATCH(inferred_status_fail.status().error_message(), + testing::ContainsRegex("select function's first parameter")); +} + +TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape4) { + ProgramShape select_program_shape_fail = ShapeUtil::MakeProgramShape( + {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(U32, {})}, pred_); + auto inferred_status_fail = ShapeInference::InferSelectAndScatterShape( + operand_shape_, select_program_shape_fail, window_, source_shape_, + init_value_shape_, scatter_program_shape_); + ASSERT_FALSE(inferred_status_fail.ok()); + ASSERT_MATCH(inferred_status_fail.status().error_message(), + testing::ContainsRegex("select function's second parameter")); +} + +TEST_F(ShapeInferenceTest, Convolve) { + ConvolutionDimensionNumbers dnums; + + // Dimension order: batch, feature, x0, x1 + Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 3, 4}); + dnums.set_batch_dimension(0); + dnums.set_feature_dimension(1); + dnums.add_spatial_dimensions(2); + dnums.add_spatial_dimensions(3); + + // Dimension order: x1, batch, feature, x0 + Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 3}); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(1); + dnums.add_kernel_spatial_dimensions(3); + dnums.add_kernel_spatial_dimensions(0); + + Window window; + auto dim0 = window.add_dimensions(); + auto dim1 = window.add_dimensions(); + dim0->set_size(3); + dim0->set_stride(2); + dim0->set_padding_low(1); + dim0->set_padding_high(1); + dim0->set_window_dilation(1); + dim0->set_base_dilation(1); + dim1->set_size(2); + dim1->set_stride(1); + dim1->set_padding_low(0); + dim1->set_padding_high(0); + dim1->set_window_dilation(1); + dim1->set_base_dilation(1); + auto inferred_status = + ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums); + ASSERT_IS_OK(inferred_status.status()); + Shape inferred_shape = inferred_status.ValueOrDie(); + ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 2, 3}), + inferred_shape)); +} + +TEST_F(ShapeInferenceTest, ConvolveWithWindowDilation) { + ConvolutionDimensionNumbers dnums; + + // Dimension order: batch, feature, x0, x1 + Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 103, 4}); + dnums.set_batch_dimension(0); + dnums.set_feature_dimension(1); + dnums.add_spatial_dimensions(2); + dnums.add_spatial_dimensions(3); + + // Dimension order: x1, batch, feature, x0 + Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 3}); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(1); + dnums.add_kernel_spatial_dimensions(3); + dnums.add_kernel_spatial_dimensions(0); + + Window window; + auto dim0 = window.add_dimensions(); + dim0->set_size(3); + dim0->set_stride(3); + dim0->set_padding_low(0); + dim0->set_padding_high(0); + dim0->set_window_dilation(6); + 
dim0->set_base_dilation(1); + + auto dim1 = window.add_dimensions(); + dim1->set_size(2); + dim1->set_stride(1); + dim1->set_padding_low(2); + dim1->set_padding_high(1); + dim1->set_window_dilation(2); + dim1->set_base_dilation(1); + auto inferred_status = + ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums); + ASSERT_IS_OK(inferred_status.status()); + Shape inferred_shape = inferred_status.ValueOrDie(); + ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 31, 5}), + inferred_shape)); +} + +TEST_F(ShapeInferenceTest, ConvolveWithBaseDilation) { + ConvolutionDimensionNumbers dnums; + + // Dimension order: batch, feature, x0, x1 + Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 3, 4}); + dnums.set_batch_dimension(0); + dnums.set_feature_dimension(1); + dnums.add_spatial_dimensions(2); + dnums.add_spatial_dimensions(3); + + // Dimension order: x1, batch, feature, x0 + Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 4}); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(1); + dnums.add_kernel_spatial_dimensions(3); + dnums.add_kernel_spatial_dimensions(0); + + Window window; + auto dim0 = window.add_dimensions(); + dim0->set_size(4); + dim0->set_stride(3); + dim0->set_padding_low(0); + dim0->set_padding_high(0); + dim0->set_window_dilation(1); + dim0->set_base_dilation(6); + + auto dim1 = window.add_dimensions(); + dim1->set_size(2); + dim1->set_stride(1); + dim1->set_padding_low(2); + dim1->set_padding_high(1); + dim1->set_window_dilation(1); + dim1->set_base_dilation(2); + auto inferred_status = + ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums); + ASSERT_IS_OK(inferred_status.status()); + Shape inferred_shape = inferred_status.ValueOrDie(); + ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 4, 9}), + inferred_shape)); +} + +TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) { + // Dimension order for this test: batch, feature, x0, x1 + Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 3, 4}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {12, 11, 3, 2}); + + ConvolutionDimensionNumbers dnums; + dnums.set_batch_dimension(3); + dnums.set_feature_dimension(2); + dnums.add_spatial_dimensions(0); + dnums.add_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(0); // duplicated with kernel_x0 + dnums.set_kernel_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + + Window window; + auto dim0 = window.add_dimensions(); + auto dim1 = window.add_dimensions(); + dim0->set_size(2); + dim0->set_stride(1); + dim0->set_padding_low(0); + dim0->set_padding_high(0); + dim1->set_size(3); + dim1->set_stride(2); + dim1->set_padding_low(1); + dim1->set_padding_high(1); + auto inferred_status = + ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_MATCH(inferred_status.status().error_message(), + testing::ContainsRegex("each dimension exactly once")); +} + +TEST_F(ShapeInferenceTest, MapThatChangesElementType) { + Shape arg = ShapeUtil::MakeShape(F32, {20}); + ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_}, s32_); + auto inferred_status = ShapeInference::InferMapShape({&arg}, to_apply); + EXPECT_IS_OK(inferred_status.status()); + Shape expected = ShapeUtil::MakeShape(S32, {20}); + EXPECT_TRUE(ShapeUtil::Equal(expected, inferred_status.ValueOrDie())); +} + +TEST_F(ShapeInferenceTest, Map) { + auto inferred_status_r1f32 
= ShapeInference::InferMapShape( + {&vector_32_, &vector_32_}, + ShapeUtil::MakeProgramShape({f32_, f32_}, f32_)); + EXPECT_IS_OK(inferred_status_r1f32.status()); + EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status_r1f32.ValueOrDie())); + + // It's OK to provide a single argument, as long as the applied arity matches + // (this degenerates to a Map). + auto inferred_status_r1f32_one = ShapeInference::InferMapShape( + {&vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_)); + EXPECT_IS_OK(inferred_status_r1f32_one.status()); + EXPECT_TRUE( + ShapeUtil::Equal(vector_32_, inferred_status_r1f32_one.ValueOrDie())); + + auto inferred_status_r2s32 = ShapeInference::InferMapShape( + {&s32matrix_64_64_, &s32matrix_64_64_, &s32matrix_64_64_}, + ShapeUtil::MakeProgramShape({s32_, s32_, s32_}, s32_)); + EXPECT_IS_OK(inferred_status_r2s32.status()); + EXPECT_TRUE( + ShapeUtil::Equal(s32matrix_64_64_, inferred_status_r2s32.ValueOrDie())); + + auto no_args_error = ShapeInference::InferMapShape( + {}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_)); + ASSERT_FALSE(no_args_error.ok()); + ASSERT_MATCH(no_args_error.status().error_message(), + testing::ContainsRegex("expects at least one argument")); + + auto args_diff_shapes_error = ShapeInference::InferMapShape( + {&vector_32_, &vector_64_}, + ShapeUtil::MakeProgramShape({f32_, f32_}, f32_)); + ASSERT_FALSE(args_diff_shapes_error.ok()); + ASSERT_MATCH( + args_diff_shapes_error.status().error_message(), + testing::ContainsRegex("requires all operands to have the same shape")); + + auto arity_error = ShapeInference::InferMapShape( + {&vector_32_, &vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_)); + ASSERT_FALSE(arity_error.ok()); + ASSERT_MATCH(arity_error.status().error_message(), + testing::ContainsRegex("function arity must match")); + + auto output_shape_error = ShapeInference::InferMapShape( + {&vector_32_, &vector_32_}, + ShapeUtil::MakeProgramShape({f32_, f32_}, vector_32_)); + ASSERT_FALSE(output_shape_error.ok()); + ASSERT_MATCH(output_shape_error.status().error_message(), + testing::ContainsRegex("result has to be a scalar")); + + auto param_shape_error = ShapeInference::InferMapShape( + {&vector_32_, &vector_32_}, + ShapeUtil::MakeProgramShape({vector_32_, f32_}, f32_)); + ASSERT_FALSE(param_shape_error.ok()); + ASSERT_MATCH(param_shape_error.status().error_message(), + testing::ContainsRegex("parameter has to be a scalar")); + + auto param_element_type_error = ShapeInference::InferMapShape( + {&vector_32_, &vector_32_}, + ShapeUtil::MakeProgramShape({f32_, s32_}, f32_)); + ASSERT_FALSE(param_element_type_error.ok()); + ASSERT_MATCH(param_element_type_error.status().error_message(), + testing::ContainsRegex("parameter type has to match argument")); + + Shape arg = ShapeUtil::MakeShape(F32, {20}); + ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_}, f32_); + auto inferred_status = ShapeInference::InferMapShape({&arg}, to_apply); + EXPECT_IS_OK(inferred_status.status()); + EXPECT_TRUE(ShapeUtil::Equal(arg, inferred_status.ValueOrDie())); + + auto inferred_status_error1 = ShapeInference::InferMapShape( + {&arg}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_)); + ASSERT_FALSE(inferred_status_error1.ok()); + ASSERT_MATCH(inferred_status_error1.status().error_message(), + testing::ContainsRegex("arity must match number of arguments")); + + auto inferred_status_error2 = ShapeInference::InferMapShape( + {&arg}, ShapeUtil::MakeProgramShape({vector_32_}, f32_)); + ASSERT_FALSE(inferred_status_error2.ok()); + 
ASSERT_MATCH(inferred_status_error2.status().error_message(), + testing::ContainsRegex("has to be a scalar")); + + auto inferred_status_error3 = ShapeInference::InferMapShape( + {&arg}, ShapeUtil::MakeProgramShape({f32_}, vector_32_)); + ASSERT_FALSE(inferred_status_error3.ok()); + ASSERT_MATCH(inferred_status_error3.status().error_message(), + testing::ContainsRegex("has to be a scalar")); + + auto inferred_status_error5 = ShapeInference::InferMapShape( + {&arg}, ShapeUtil::MakeProgramShape({s32_}, s32_)); + ASSERT_FALSE(inferred_status_error5.ok()); + ASSERT_MATCH(inferred_status_error5.status().error_message(), + testing::ContainsRegex("parameter type has to match argument")); +} + +TEST_F(ReduceShapeInferenceTest, ReduceVectorToScalar) { + ExpectInferredReduceShape(f32_, ShapeUtil::MakeShape(F32, {128}), + /*dimensions_to_reduce=*/{0}); +} + +TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongFirstDimension) { + ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {3, 4}), + ShapeUtil::MakeShape(F32, {2, 3, 4}), + /*dimensions_to_reduce=*/{0}); +} + +TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongMiddleDimension) { + ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {2, 4}), + ShapeUtil::MakeShape(F32, {2, 3, 4}), + /*dimensions_to_reduce=*/{1}); +} + +TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongFirstTwoDimensions) { + ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {4}), + ShapeUtil::MakeShape(F32, {2, 3, 4}), + /*dimensions_to_reduce=*/{0, 1}); +} + +TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongLastTwoDimensions) { + ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {2}), + ShapeUtil::MakeShape(F32, {2, 3, 4}), + /*dimensions_to_reduce=*/{1, 2}); +} + +TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongFirstAndLastDimensions) { + ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {3}), + ShapeUtil::MakeShape(F32, {2, 3, 4}), + /*dimensions_to_reduce=*/{0, 2}); + + // Check that the order of dimensions_to_reduce doesn't matter. 
+ ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {3}), + ShapeUtil::MakeShape(F32, {2, 3, 4}), + /*dimensions_to_reduce=*/{2, 0}); +} + +TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongAllDimensions) { + ExpectInferredReduceShape(f32_, ShapeUtil::MakeShape(F32, {2, 3, 4}), + /*dimensions_to_reduce=*/{0, 1, 2}); +} + +TEST_F(ReduceShapeInferenceTest, ErrorOutOfBoundsDimension) { + ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_); + auto inferred_status = ShapeInference::InferReduceShape( + ShapeUtil::MakeShape(F32, {5, 3}), f32_, /*dimensions_to_reduce=*/{3, 4}, + to_apply); + EXPECT_FALSE(inferred_status.ok()); + EXPECT_MATCH(inferred_status.status().error_message(), + testing::ContainsRegex("out-of-bounds dimension")); +} + +TEST_F(ReduceShapeInferenceTest, ErrorToApplyArity) { + ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_, f32_}, f32_); + auto inferred_status = + ShapeInference::InferReduceShape(ShapeUtil::MakeShape(F32, {5, 3}), f32_, + /*dimensions_to_reduce=*/{0}, to_apply); + EXPECT_FALSE(inferred_status.ok()); + EXPECT_MATCH(inferred_status.status().error_message(), + testing::ContainsRegex("take 2 parameters")); +} + +TEST_F(ReduceShapeInferenceTest, ErrorElementTypeVsApplyType) { + ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, s32_); + auto inferred_status = + ShapeInference::InferReduceShape(ShapeUtil::MakeShape(F32, {5, 3}), f32_, + /*dimensions_to_reduce=*/{0}, to_apply); + EXPECT_FALSE(inferred_status.ok()); + EXPECT_MATCH(inferred_status.status().error_message(), + testing::ContainsRegex("first parameter shape differs")); +} + +TEST_F(ShapeInferenceTest, InferSliceShapeRank2) { + Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64}); + auto inferred_status = + ShapeInference::InferSliceShape(matrix_shape, {32, 0}, {64, 64}); + ASSERT_IS_OK(inferred_status.status()); + Shape inferred = inferred_status.ValueOrDie(); + ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {32, 64}), inferred)); +} + +TEST_F(ShapeInferenceTest, InferOobSliceShapeRank2) { + Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64}); + auto inferred_status = + ShapeInference::InferSliceShape(matrix_shape, {127, 0}, {129, 2}); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_EQ(tensorflow::error::INVALID_ARGUMENT, + inferred_status.status().code()); +} + +TEST_F(ShapeInferenceTest, InferSliceShapeRank1) { + Shape vector_shape = ShapeUtil::MakeShape(F32, {17}); + auto inferred_status = + ShapeInference::InferSliceShape(vector_shape, {2}, {4}); + ASSERT_TRUE(inferred_status.ok()); + Shape inferred = inferred_status.ValueOrDie(); + ASSERT_TRUE(ShapeUtil::Equal(inferred, ShapeUtil::MakeShape(F32, {2}))); +} + +TEST_F(ShapeInferenceTest, InferConstIndexShape) { + Shape tuple_shape = ShapeUtil::MakeTupleShape({f32_, s32_}); + auto inferred0_status = + ShapeInference::InferGetTupleElementShape(tuple_shape, 0); + auto inferred1_status = + ShapeInference::InferGetTupleElementShape(tuple_shape, 1); + ASSERT_IS_OK(inferred0_status.status()); + ASSERT_IS_OK(inferred1_status.status()); + ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred0_status.ValueOrDie())); + ASSERT_TRUE(ShapeUtil::Equal(s32_, inferred1_status.ValueOrDie())); +} + +TEST_F(ShapeInferenceTest, InferPowShape) { + auto ten_floats = ShapeUtil::MakeShape(F32, {10}); + auto inferred_status = + ShapeInference::InferBinaryOpShape(BINOP_POW, ten_floats, f32_, {}); + ASSERT_IS_OK(inferred_status.status()); + ASSERT_TRUE(ShapeUtil::Equal(ten_floats, inferred_status.ValueOrDie())); +} + 
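// A minimal usage sketch of InferBinaryOpShape with explicit broadcast
// dimensions (the comparison tests above pass {} because one operand is a
// scalar). The shapes and the {1} broadcast dimension mirror the
// BinOpBroadcastMatrixVector test further below, so this is illustrative
// rather than additional coverage:
//
//   Shape mat = ShapeUtil::MakeShape(F32, {16, 8});
//   Shape vec = ShapeUtil::MakeShape(F32, {8});
//   StatusOr<Shape> sum_shape = ShapeInference::InferBinaryOpShape(
//       BinaryOperation::BINOP_ADD, mat, vec, /*broadcast_dimensions=*/{1});
//   // On success, sum_shape.ValueOrDie() is f32[16,8]: vec is lined up with
//   // dimension 1 (size 8) of mat.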
+TEST_F(ShapeInferenceTest, InferCompareShapeEq) { + auto ten_floats = ShapeUtil::MakeShape(F32, {10}); + auto inferred_status = + ShapeInference::InferBinaryOpShape(BINOP_EQ, ten_floats, f32_, {}); + ASSERT_IS_OK(inferred_status.status()); + ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), + inferred_status.ValueOrDie())); +} + +TEST_F(ShapeInferenceTest, InferCompareShapeGe) { + auto ten_floats = ShapeUtil::MakeShape(F32, {10}); + auto inferred_status = + ShapeInference::InferBinaryOpShape(BINOP_GE, ten_floats, f32_, {}); + ASSERT_IS_OK(inferred_status.status()); + ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), + inferred_status.ValueOrDie())); +} + +TEST_F(ShapeInferenceTest, InferCompareShapeGt) { + auto ten_floats = ShapeUtil::MakeShape(F32, {10}); + auto inferred_status = + ShapeInference::InferBinaryOpShape(BINOP_GT, ten_floats, f32_, {}); + ASSERT_IS_OK(inferred_status.status()); + ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), + inferred_status.ValueOrDie())); +} + +TEST_F(ShapeInferenceTest, InferCompareShapeLe) { + auto ten_floats = ShapeUtil::MakeShape(F32, {10}); + auto inferred_status = + ShapeInference::InferBinaryOpShape(BINOP_LE, ten_floats, f32_, {}); + ASSERT_IS_OK(inferred_status.status()); + ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), + inferred_status.ValueOrDie())); +} + +TEST_F(ShapeInferenceTest, InferCompareShapeLt) { + auto ten_floats = ShapeUtil::MakeShape(F32, {10}); + auto inferred_status = + ShapeInference::InferBinaryOpShape(BINOP_LT, ten_floats, f32_, {}); + ASSERT_IS_OK(inferred_status.status()); + ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), + inferred_status.ValueOrDie())); +} + +TEST_F(ShapeInferenceTest, InferCompareShapeNe) { + auto ten_floats = ShapeUtil::MakeShape(F32, {10}); + auto inferred_status = + ShapeInference::InferBinaryOpShape(BINOP_NE, ten_floats, f32_, {}); + ASSERT_IS_OK(inferred_status.status()); + ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), + inferred_status.ValueOrDie())); +} + +TEST_F(ShapeInferenceTest, BroadcastScalar) { + for (auto element_type : {F32, U32, S8}) { + const Shape scalar_shape = ShapeUtil::MakeShape(element_type, {}); + { // no-op scalar broadcast + auto status = ShapeInference::InferBroadcastShape(scalar_shape, {}); + ASSERT_IS_OK(status.status()); + ASSERT_TRUE(ShapeUtil::Equal(scalar_shape, status.ValueOrDie())); + } + const Shape oned_shape = ShapeUtil::MakeShape(element_type, {3}); + { // scalar -> 1d broadcast + auto status = ShapeInference::InferBroadcastShape(scalar_shape, {3}); + ASSERT_IS_OK(status.status()); + ASSERT_TRUE(ShapeUtil::Equal(oned_shape, status.ValueOrDie())); + } + { // no-op 1d broadcast + auto status = ShapeInference::InferBroadcastShape(oned_shape, {}); + ASSERT_IS_OK(status.status()); + ASSERT_TRUE(ShapeUtil::Equal(oned_shape, status.ValueOrDie())); + } + const Shape twod_shape = ShapeUtil::MakeShape(element_type, {2, 3}); + { // scalar -> 2d broadcast + auto status = ShapeInference::InferBroadcastShape(scalar_shape, {2, 3}); + ASSERT_IS_OK(status.status()); + ASSERT_TRUE(ShapeUtil::Equal(twod_shape, status.ValueOrDie())); + } + { // 1d -> 2d broadcast + auto status = ShapeInference::InferBroadcastShape(oned_shape, {2}); + ASSERT_IS_OK(status.status()); + ASSERT_TRUE(ShapeUtil::Equal(twod_shape, status.ValueOrDie())); + } + } +} + +// scalar vector: error +TEST_F(ShapeInferenceTest, ScalarDotVector) { + auto inferred_status = + ShapeInference::InferBinaryOpShape(BINOP_DOT, f32_, 
vector_32_, {}); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_MATCH(inferred_status.status().error_message(), + testing::ContainsRegex("dot only supports rank")); +} + +// 3D 2D: error +TEST_F(ShapeInferenceTest, DotWithRankHigherThanTwo) { + auto inferred_status = ShapeInference::InferBinaryOpShape( + BINOP_DOT, ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, {}); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_MATCH(inferred_status.status().error_message(), + testing::ContainsRegex("dot only supports rank")); +} + +// vector vector -> scalar +TEST_F(ShapeInferenceTest, VectorDotVector) { + auto inferred_status = + ShapeInference::InferBinaryOpShape(BINOP_DOT, vector_64_, vector_64_, {}); + ASSERT_IS_OK(inferred_status.status()); + ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred_status.ValueOrDie())); + auto inferred_status_mismatch = + ShapeInference::InferBinaryOpShape(BINOP_DOT, vector_64_, vector_32_, {}); + ASSERT_FALSE(inferred_status_mismatch.ok()); +} + +// matrix vector -> vector +TEST_F(ShapeInferenceTest, MatrixDotVector) { + auto inferred_status = ShapeInference::InferBinaryOpShape( + BinaryOperation::BINOP_DOT, matrix_32_64_, vector_64_, {}); + ASSERT_IS_OK(inferred_status.status()); + ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), vector_32_)); + auto inferred_status_mismatch = ShapeInference::InferBinaryOpShape( + BinaryOperation::BINOP_DOT, matrix_32_64_, vector_32_, {}); + ASSERT_FALSE(inferred_status_mismatch.ok()); +} + +// vector matrix -> vector +TEST_F(ShapeInferenceTest, VectorDotMatrix) { + auto inferred_status = ShapeInference::InferBinaryOpShape( + BinaryOperation::BINOP_DOT, vector_32_, matrix_32_64_, {}); + ASSERT_IS_OK(inferred_status.status()); + ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), vector_64_)); + auto inferred_status_mismatch = ShapeInference::InferBinaryOpShape( + BinaryOperation::BINOP_DOT, vector_64_, matrix_32_64_, {}); + ASSERT_FALSE(inferred_status_mismatch.ok()); +} + +// matrix matrix -> matrix +TEST_F(ShapeInferenceTest, MatrixDotMatrix) { + auto inferred_status_match = ShapeInference::InferBinaryOpShape( + BinaryOperation::BINOP_DOT, matrix_32_64_, matrix_64_48_, {}); + ASSERT_IS_OK(inferred_status_match.status()); + ASSERT_TRUE( + ShapeUtil::Equal(inferred_status_match.ValueOrDie(), matrix_32_48_)) + << "inferred: " + << ShapeUtil::HumanString(inferred_status_match.ValueOrDie()) + << " expected: " << ShapeUtil::HumanString(matrix_64_48_); + auto inferred_status_mismatch = ShapeInference::InferBinaryOpShape( + BinaryOperation::BINOP_DOT, matrix_32_64_, matrix_32_64_, {}); + ASSERT_FALSE(inferred_status_mismatch.ok()); +} + +TEST_F(ShapeInferenceTest, BinOpBroadcastMatrixVector) { + // Test variations of broadcasting a vector for a binary add with a + // matrix. 
+ const Shape mat = ShapeUtil::MakeShape(F32, {16, 8}); + const Shape vec8 = ShapeUtil::MakeShape(F32, {8}); + const Shape vec16 = ShapeUtil::MakeShape(F32, {16}); + + auto inferred_status_match = ShapeInference::InferBinaryOpShape( + BinaryOperation::BINOP_ADD, mat, vec8, {1}); + ASSERT_IS_OK(inferred_status_match.status()); + ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), mat)); + + auto inferred_status_mismatch = ShapeInference::InferBinaryOpShape( + BinaryOperation::BINOP_ADD, mat, vec8, {0}); + ASSERT_FALSE(inferred_status_mismatch.ok()); + + inferred_status_match = ShapeInference::InferBinaryOpShape( + BinaryOperation::BINOP_ADD, mat, vec16, {0}); + ASSERT_IS_OK(inferred_status_match.status()); + ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), mat)); + + inferred_status_mismatch = ShapeInference::InferBinaryOpShape( + BinaryOperation::BINOP_ADD, mat, vec16, {1}); + ASSERT_FALSE(inferred_status_mismatch.ok()); +} + +TEST_F(ShapeInferenceTest, BinOpBroadcastCubeMatrix) { + // Test variations of broadcasting a matrix for a binary add with a cube. + const Shape cube = ShapeUtil::MakeShape(F32, {16, 8, 4}); + const Shape matrix8_4 = ShapeUtil::MakeShape(F32, {8, 4}); + const Shape matrix16_4 = ShapeUtil::MakeShape(F32, {16, 4}); + const Shape matrix16_8 = ShapeUtil::MakeShape(F32, {16, 8}); + + auto inferred_status_match = ShapeInference::InferBinaryOpShape( + BinaryOperation::BINOP_ADD, cube, matrix8_4, {1, 2}); + ASSERT_IS_OK(inferred_status_match.status()); + ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), cube)); + + inferred_status_match = ShapeInference::InferBinaryOpShape( + BinaryOperation::BINOP_ADD, cube, matrix16_4, {0, 2}); + ASSERT_IS_OK(inferred_status_match.status()); + ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), cube)); + + inferred_status_match = ShapeInference::InferBinaryOpShape( + BinaryOperation::BINOP_ADD, cube, matrix16_8, {0, 1}); + ASSERT_IS_OK(inferred_status_match.status()); + ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), cube)); +} + +TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) { + // Test various errors with the broadcast argument. 
+ const Shape tensor = ShapeUtil::MakeShape(F32, {16, 8, 4}); + const Shape tensor8_8_8 = ShapeUtil::MakeShape(F32, {8, 8, 8}); + const Shape vec8 = ShapeUtil::MakeShape(F32, {8}); + const Shape matrix8_4 = ShapeUtil::MakeShape(F32, {8, 4}); + const Shape matrix8_8 = ShapeUtil::MakeShape(F32, {8, 8}); + + // "magical" broadcast rejected + auto inferred_status_error1 = ShapeInference::InferBinaryOpShape( + BinaryOperation::BINOP_ADD, tensor, vec8, {}); + ASSERT_FALSE(inferred_status_error1.ok()); + ASSERT_MATCH(inferred_status_error1.status().error_message(), + testing::ContainsRegex("automatic")); + + // broadcast_dimension out of bounds for tensor's rank + auto inferred_status_error2 = ShapeInference::InferBinaryOpShape( + BinaryOperation::BINOP_ADD, tensor, vec8, {3}); + ASSERT_FALSE(inferred_status_error2.ok()); + ASSERT_MATCH( + inferred_status_error2.status().error_message(), + testing::ContainsRegex("broadcast dimension number .* too large")); + + // broadcast_dimension doesn't match corresponding dimension + auto inferred_status_error3 = ShapeInference::InferBinaryOpShape( + BinaryOperation::BINOP_ADD, tensor, vec8, {0}); + ASSERT_FALSE(inferred_status_error3.ok()); + ASSERT_MATCH(inferred_status_error3.status().error_message(), + testing::ContainsRegex("broadcast dimension 0 mismatch")); + + // broadcast_dimensions list too long + auto inferred_status_error4 = ShapeInference::InferBinaryOpShape( + BinaryOperation::BINOP_ADD, tensor, matrix8_4, {0, 1, 2}); + ASSERT_FALSE(inferred_status_error4.ok()); + ASSERT_MATCH( + inferred_status_error4.status().error_message(), + testing::ContainsRegex("size of broadcast_dimensions has to match")); + + // there's a dimension above the rank of the tensor + auto inferred_status_error5 = ShapeInference::InferBinaryOpShape( + BinaryOperation::BINOP_ADD, tensor, matrix8_4, {3, 0}); + ASSERT_FALSE(inferred_status_error5.ok()); + ASSERT_MATCH( + inferred_status_error5.status().error_message(), + testing::ContainsRegex("broadcast dimension number .* too large")); + + // broadcasting dimensions don't match in this order + auto inferred_status_error6 = ShapeInference::InferBinaryOpShape( + BinaryOperation::BINOP_ADD, tensor, matrix8_4, {2, 1}); + ASSERT_FALSE(inferred_status_error6.ok()); + ASSERT_MATCH(inferred_status_error6.status().error_message(), + testing::ContainsRegex("broadcast dimension 0 mismatch")); + + // The following two tests make sure that broadcasting dimensions are listed + // in a proper (strictly increasing) order, even if the lower-rank array + // matches the higher-rank array in many different ways. + auto inferred_status_error7 = ShapeInference::InferBinaryOpShape( + BinaryOperation::BINOP_ADD, tensor8_8_8, matrix8_8, {0, 0}); + ASSERT_FALSE(inferred_status_error7.ok()); + ASSERT_MATCH(inferred_status_error7.status().error_message(), + testing::ContainsRegex("broadcast dimensions order is wrong")); + + auto inferred_status_error8 = ShapeInference::InferBinaryOpShape( + BinaryOperation::BINOP_ADD, tensor8_8_8, matrix8_8, {1, 0}); + ASSERT_FALSE(inferred_status_error8.ok()); + ASSERT_MATCH(inferred_status_error8.status().error_message(), + testing::ContainsRegex("broadcast dimensions order is wrong")); +} + +// Tests for the while instruction with proper shapes. 
+TEST_F(ShapeInferenceTest, WhileWithCorrectShapes) { + Shape result_shape = ShapeUtil::MakeTupleShape({s32_, vector_32_}); + ProgramShape cond = ShapeUtil::MakeProgramShape({result_shape}, pred_); + ProgramShape body = ShapeUtil::MakeProgramShape({result_shape}, result_shape); + auto inferred_status = + ShapeInference::InferWhileShape(cond, body, result_shape); + ASSERT_IS_OK(inferred_status.status()); + Shape inferred = inferred_status.ValueOrDie(); + ASSERT_TRUE(ShapeUtil::Equal(result_shape, inferred)); +} + +// Tests for the while instruction with wrong shapes. +TEST_F(ShapeInferenceTest, WhileWithBadShapes) { + Shape result_shape = ShapeUtil::MakeTupleShape({s32_, vector_32_}); + ProgramShape cond = ShapeUtil::MakeProgramShape({result_shape}, pred_); + ProgramShape body = ShapeUtil::MakeProgramShape({result_shape}, result_shape); + + auto bad_shape_1 = ShapeUtil::MakeProgramShape({s32_, result_shape}, pred_); + auto inferred_status_error1 = + ShapeInference::InferWhileShape(bad_shape_1, body, result_shape); + ASSERT_FALSE(inferred_status_error1.ok()); + ASSERT_MATCH(inferred_status_error1.status().error_message(), + testing::ContainsRegex("condition must take 1 arguments")); + + auto bad_shape_2 = + ShapeUtil::MakeProgramShape({s32_, result_shape}, result_shape); + auto inferred_status_error2 = + ShapeInference::InferWhileShape(cond, bad_shape_2, result_shape); + ASSERT_FALSE(inferred_status_error2.ok()); + ASSERT_MATCH(inferred_status_error2.status().error_message(), + testing::ContainsRegex("body must take 1 arguments")); + + auto bad_shape_3 = ShapeUtil::MakeProgramShape({result_shape}, s32_); + auto inferred_status_error3 = + ShapeInference::InferWhileShape(bad_shape_3, body, result_shape); + ASSERT_FALSE(inferred_status_error3.ok()); + ASSERT_MATCH(inferred_status_error3.status().error_message(), + testing::ContainsRegex("condition must return a boolean")); + + auto bad_shape_4 = ShapeUtil::MakeProgramShape({result_shape}, vector_32_); + auto inferred_status_error4 = + ShapeInference::InferWhileShape(cond, bad_shape_4, result_shape); + ASSERT_FALSE(inferred_status_error4.ok()); + ASSERT_MATCH(inferred_status_error4.status().error_message(), + testing::ContainsRegex("parameter of condition and body")); +} + +// Tests for the concatenate instruction with proper shapes. +TEST_F(ShapeInferenceTest, ConcatenateWithCorrectShapes) { + auto inferred_status_1 = ShapeInference::InferConcatOpShape( + {&vector_32_, &vector_64_}, /*dimension=*/0); + ASSERT_IS_OK(inferred_status_1.status()); + Shape inferred_1 = inferred_status_1.ValueOrDie(); + ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {96}), inferred_1)); + + auto inferred_status_2 = ShapeInference::InferConcatOpShape( + {&vector_32_, &vector_64_, &vector_32_}, /*dimension=*/0); + ASSERT_IS_OK(inferred_status_2.status()); + Shape inferred_2 = inferred_status_2.ValueOrDie(); + ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {128}), inferred_2)); + + auto inferred_status_3 = ShapeInference::InferConcatOpShape( + {&matrix_32_48_, &matrix_32_64_, &matrix_32_48_}, /*dimension=*/1); + ASSERT_IS_OK(inferred_status_3.status()); + Shape inferred_3 = inferred_status_3.ValueOrDie(); + ASSERT_TRUE( + ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {32, 160}), inferred_3)); +} + +// Tests for the concatenate instruction with wrong shapes. 
+TEST_F(ShapeInferenceTest, ConcatenateWithBadShapes) { + auto inferred_status_error1 = + ShapeInference::InferConcatOpShape({}, /*dimension=*/0); + ASSERT_FALSE(inferred_status_error1.ok()); + ASSERT_MATCH( + inferred_status_error1.status().error_message(), + testing::ContainsRegex("Concatenate expects at least one argument")); + + auto inferred_status_error2 = + ShapeInference::InferConcatOpShape({&vector_32_}, /*dimension=*/-1); + ASSERT_FALSE(inferred_status_error2.ok()); + ASSERT_MATCH(inferred_status_error2.status().error_message(), + testing::ContainsRegex( + "dimension to concatenate along out of bounds: -1")); + + auto inferred_status_error3 = + ShapeInference::InferConcatOpShape({&vector_32_}, /*dimension=*/1); + ASSERT_FALSE(inferred_status_error3.ok()); + ASSERT_MATCH(inferred_status_error3.status().error_message(), + testing::ContainsRegex( + "dimension to concatenate along out of bounds: 1")); + + Shape tuple = ShapeUtil::MakeTupleShape({vector_32_}); + auto inferred_status_error4 = ShapeInference::InferConcatOpShape( + {&vector_32_, &tuple}, /*dimension=*/0); + ASSERT_FALSE(inferred_status_error4.ok()); + ASSERT_MATCH( + inferred_status_error4.status().error_message(), + testing::ContainsRegex( + "Expected non-tuple argument for operand of concatenation.")); + + const Shape vector_s32 = ShapeUtil::MakeShape(S32, {32}); + auto inferred_status_error5 = ShapeInference::InferConcatOpShape( + {&vector_32_, &vector_s32}, /*dimension=*/0); + ASSERT_FALSE(inferred_status_error5.ok()); + ASSERT_MATCH(inferred_status_error5.status().error_message(), + testing::ContainsRegex( + "cannot concatenate arrays with different element types")); + + auto inferred_status_error6 = ShapeInference::InferConcatOpShape( + {&matrix_32_48_, &matrix_32_64_}, /*dimension=*/0); + ASSERT_FALSE(inferred_status_error6.ok()); + ASSERT_MATCH( + inferred_status_error6.status().error_message(), + testing::ContainsRegex("cannot concatenate arrays that differ in " + "dimensions other than the one being " + "concatenated")); +} + +TEST_F(ShapeInferenceTest, Pad) { + Shape input_shape = ShapeUtil::MakeShape(F32, {10, 25}); + Shape padding_value_shape = ShapeUtil::MakeShape(F32, {}); + // Padding for dimension 0: {low: 0, high: 2, interior: 3} + // Padding for dimension 1: {low: 1, high: 5, interior: 0} + PaddingConfig padding_config; + auto dimension0 = padding_config.add_dimensions(); + dimension0->set_edge_padding_low(0); + dimension0->set_edge_padding_high(2); + dimension0->set_interior_padding(3); + auto dimension1 = padding_config.add_dimensions(); + dimension1->set_edge_padding_low(1); + dimension1->set_edge_padding_high(5); + dimension1->set_interior_padding(0); + + auto inferred_status = ShapeInference::InferPadShape( + input_shape, padding_value_shape, padding_config); + ASSERT_IS_OK(inferred_status.status()); + Shape inferred_shape = inferred_status.ValueOrDie(); + ASSERT_TRUE( + ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {39, 31}), inferred_shape)); +} + +TEST_F(ShapeInferenceTest, Reverse) { + Shape input_shape = ShapeUtil::MakeShape(F32, {10, 25}); + + auto inferred_status = ShapeInference::InferReverseShape(input_shape, {0, 1}); + ASSERT_IS_OK(inferred_status.status()); + Shape inferred_shape = inferred_status.ValueOrDie(); + ASSERT_TRUE(ShapeUtil::Equal(input_shape, inferred_shape)); +} + +TEST_F(ShapeInferenceTest, ReverseInvalidDimension) { + Shape input_shape = ShapeUtil::MakeShape(F32, {10, 25}); + + auto inferred_status_error0 = + ShapeInference::InferReverseShape(input_shape, {0, 2}); + 
ASSERT_FALSE(inferred_status_error0.ok()); + ASSERT_MATCH(inferred_status_error0.status().error_message(), + testing::ContainsRegex("out-of-bounds")); + + auto inferred_status_error1 = + ShapeInference::InferReverseShape(input_shape, {0, -1}); + ASSERT_FALSE(inferred_status_error1.ok()); + ASSERT_MATCH(inferred_status_error1.status().error_message(), + testing::ContainsRegex("out-of-bounds")); + + auto inferred_status_error2 = + ShapeInference::InferReverseShape(input_shape, {0, 0}); + ASSERT_FALSE(inferred_status_error2.ok()); + ASSERT_MATCH(inferred_status_error2.status().error_message(), + testing::ContainsRegex("duplicated")); + + Shape tuple_shape = ShapeUtil::MakeTupleShape({input_shape, input_shape}); + auto inferred_status_error3 = + ShapeInference::InferReverseShape(tuple_shape, {0}); + ASSERT_FALSE(inferred_status_error3.ok()); + ASSERT_MATCH(inferred_status_error3.status().error_message(), + testing::ContainsRegex("Expected non-tuple argument")); +} + +TEST_F(ShapeInferenceTest, Call) { + auto inferred_status0 = + ShapeInference::InferCallShape({}, ShapeUtil::MakeProgramShape({}, f32_)); + EXPECT_IS_OK(inferred_status0.status()); + EXPECT_TRUE(ShapeUtil::Equal(f32_, inferred_status0.ValueOrDie())); + + auto inferred_status1 = ShapeInference::InferCallShape( + {&f32_, &s32_, &pred_, &vector_32_, &matrix_32_48_}, + ShapeUtil::MakeProgramShape( + {f32_, s32_, pred_, vector_32_, matrix_32_48_}, s32matrix_64_64_)); + EXPECT_IS_OK(inferred_status1.status()); + EXPECT_TRUE( + ShapeUtil::Equal(s32matrix_64_64_, inferred_status1.ValueOrDie())); + + auto inferred_status_error0 = ShapeInference::InferCallShape( + {}, ShapeUtil::MakeProgramShape({f32_}, f32_)); + EXPECT_FALSE(inferred_status_error0.ok()); + EXPECT_MATCH(inferred_status_error0.status().error_message(), + testing::ContainsRegex("arity must match")); + + auto inferred_status_error1 = ShapeInference::InferCallShape( + {&f32_}, ShapeUtil::MakeProgramShape({}, f32_)); + EXPECT_FALSE(inferred_status_error1.ok()); + EXPECT_MATCH(inferred_status_error1.status().error_message(), + testing::ContainsRegex("arity must match")); + + auto inferred_status_error2 = ShapeInference::InferCallShape( + {&f32_}, ShapeUtil::MakeProgramShape({s32_}, f32_)); + EXPECT_FALSE(inferred_status_error2.ok()); + EXPECT_MATCH(inferred_status_error2.status().error_message(), + testing::ContainsRegex("parameter must match argument")); +} + +TEST_F(ShapeInferenceTest, Transpose) { + Shape a_shape = ShapeUtil::MakeShape(F32, {2, 3, 4, 5}); + auto inferred_shape_and_status = + ShapeInference::InferTransposeShape(a_shape, {1, 2, 3, 0}); + EXPECT_IS_OK(inferred_shape_and_status); + Shape inferred_shape = inferred_shape_and_status.ValueOrDie(); + EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, + ShapeUtil::MakeShape(F32, {3, 4, 5, 2}))); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc new file mode 100644 index 0000000000..cf49fd72b7 --- /dev/null +++ b/tensorflow/compiler/xla/service/shaped_buffer.cc @@ -0,0 +1,168 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/shaped_buffer.h"
+
+#include <set>
+#include <string>
+#include <utility>
+
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+
+/* static */ StatusOr<std::unique_ptr<ShapedBuffer>>
+ShapedBuffer::MakeShapedBuffer(const Shape& shape,
+                               const perftools::gputools::Platform* platform,
+                               int device_ordinal) {
+  if (!LayoutUtil::HasLayout(shape)) {
+    return InvalidArgument("Shape must have a layout: %s",
+                           ShapeUtil::HumanStringWithLayout(shape).c_str());
+  }
+  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape));
+  return WrapUnique(new ShapedBuffer(shape, platform, device_ordinal));
+}
+
+/* static */ StatusOr<std::unique_ptr<ShapedBuffer>>
+ShapedBuffer::MakeArrayShapedBuffer(
+    const Shape& shape, const perftools::gputools::Platform* platform,
+    int device_ordinal, const perftools::gputools::DeviceMemoryBase& buffer) {
+  if (ShapeUtil::IsTuple(shape)) {
+    return InvalidArgument("Shape must be an array: %s",
+                           ShapeUtil::HumanStringWithLayout(shape).c_str());
+  }
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<ShapedBuffer> shaped_buffer,
+                      MakeShapedBuffer(shape, platform, device_ordinal));
+  *shaped_buffer->mutable_shape_index_to_buffer_entry()->mutable_element({}) =
+      0;
+  *shaped_buffer->mutable_buffers() = {buffer};
+  return std::move(shaped_buffer);
+}
+
+/* static */ StatusOr<std::unique_ptr<ShapedBuffer>>
+ShapedBuffer::MakeUnnestedTupleShapedBuffer(
+    const Shape& shape, const perftools::gputools::Platform* platform,
+    int device_ordinal,
+    const tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
+        buffers) {
+  if (!ShapeUtil::IsTuple(shape) || ShapeUtil::IsNestedTuple(shape)) {
+    return InvalidArgument("Shape must be an unnested tuple: %s",
+                           ShapeUtil::HumanStringWithLayout(shape).c_str());
+  }
+  if (buffers.size() != ShapeUtil::TupleElementCount(shape)) {
+    return InvalidArgument("Tuple has %lld elements, but %zu buffers given",
+                           ShapeUtil::TupleElementCount(shape), buffers.size());
+  }
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<ShapedBuffer> shaped_buffer,
+                      MakeShapedBuffer(shape, platform, device_ordinal));
+  TF_CHECK_OK(shaped_buffer->mutable_shape_index_to_buffer_entry()
+                  ->ForEachMutableElement(
+                      [](const ShapeIndex& index, bool is_leaf,
+                         size_t* buffer_element) -> tensorflow::Status {
+                        if (is_leaf) {
+                          CHECK_EQ(index.size(), 1);
+                          *buffer_element = index[0];
+                        }
+                        return tensorflow::Status::OK();
+                      }));
+  shaped_buffer->mutable_buffers()->reserve(buffers.size());
+  for (const perftools::gputools::DeviceMemoryBase& memory_base : buffers) {
+    shaped_buffer->mutable_buffers()->push_back(memory_base);
+  }
+  return std::move(shaped_buffer);
+}
+
+ShapedBuffer::ShapedBuffer(const Shape& shape,
+                           const perftools::gputools::Platform* platform,
+                           int device_ordinal)
+    : shape_(shape),
+      shape_index_to_buffer_entry_(shape),
+      platform_(platform),
+      device_ordinal_(device_ordinal) {}
+
+const perftools::gputools::DeviceMemoryBase& ShapedBuffer::buffer(
+    const ShapeIndex& index) const {
+  // Buffers are only set at the leaves (array elements of the shape).
+  CHECK(shape_index_to_buffer_entry_.IsLeaf(index));
+  return buffers_[shape_index_to_buffer_entry_.element(index)];
+}
+
+perftools::gputools::DeviceMemoryBase* ShapedBuffer::mutable_buffer(
+    const ShapeIndex& index) {
+  // Buffers are only set at the leaves (array elements of the shape).
+  CHECK(shape_index_to_buffer_entry_.IsLeaf(index));
+  return &buffers_[shape_index_to_buffer_entry_.element(index)];
+}
+
+/* static */ StatusOr<std::unique_ptr<ScopedShapedBuffer>>
+ScopedShapedBuffer::MakeScopedShapedBuffer(const Shape& shape,
+                                           DeviceMemoryAllocator* allocator,
+                                           int device_ordinal) {
+  if (!LayoutUtil::HasLayout(shape)) {
+    return InvalidArgument("Shape must have a layout: %s",
+                           ShapeUtil::HumanStringWithLayout(shape).c_str());
+  }
+  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape));
+  auto shaped_buffer =
+      WrapUnique(new ScopedShapedBuffer(shape, allocator, device_ordinal));
+
+  // Allocate an appropriate sized buffer for each array element in the shape.
+  TF_RETURN_IF_ERROR(
+      shaped_buffer->shape_index_to_buffer_entry_.ForEachMutableElement(
+          [&shaped_buffer](const ShapeIndex& index, bool is_leaf,
+                           size_t* buffer_entry) -> tensorflow::Status {
+            if (is_leaf) {
+              TF_ASSIGN_OR_RETURN(
+                  perftools::gputools::DeviceMemoryBase memory_base,
+                  shaped_buffer->allocator_->Allocate(
+                      shaped_buffer->device_ordinal(),
+                      ShapeUtil::ByteSizeOf(ShapeUtil::GetSubshape(
+                          shaped_buffer->shape(), index))));
+              shaped_buffer->buffers_.push_back(memory_base);
+              *buffer_entry = shaped_buffer->buffers_.size() - 1;
+            }
+            return tensorflow::Status::OK();
+          }));
+  return std::move(shaped_buffer);
+}
+
+ScopedShapedBuffer::ScopedShapedBuffer(const Shape& shape,
+                                       DeviceMemoryAllocator* allocator,
+                                       int device_ordinal)
+    : ShapedBuffer(shape, allocator->platform(), device_ordinal),
+      allocator_(allocator) {}
+
+ScopedShapedBuffer::~ScopedShapedBuffer() {
+  // Deallocate all non-null buffers. A buffer may appear in more than one spot
+  // in the shape (eg, a tuple with a repeated element) so keep track of what
+  // has been deallocated.
+  std::set<void*> deallocated_opaques;
+  for (perftools::gputools::DeviceMemoryBase& memory_base : buffers_) {
+    if (!memory_base.is_null() &&
+        deallocated_opaques.count(memory_base.opaque()) == 0) {
+      deallocated_opaques.insert(memory_base.opaque());
+      TF_CHECK_OK(
+          this->allocator_->Deallocate(this->device_ordinal(), &memory_base));
+    }
+  }
+}
+
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/shaped_buffer.h b/tensorflow/compiler/xla/service/shaped_buffer.h
new file mode 100644
index 0000000000..aa3b932c4e
--- /dev/null
+++ b/tensorflow/compiler/xla/service/shaped_buffer.h
@@ -0,0 +1,137 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SHAPED_BUFFER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_SHAPED_BUFFER_H_
+
+#include
+
+#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
+#include "tensorflow/compiler/xla/shape_tree.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+// Class which encapsulates a buffer or set of buffers containing data of a
+// particular XLA shape. Used as the zero-copy execution interface for an XLA
+// client running in the same process as the service (LocalClient).
+class ShapedBuffer {
+ public:
+  // Creates a ShapedBuffer of arbitrary shape. All buffer pointers
+  // (DeviceMemoryBase) in the returned ShapedBuffer are initialized to null.
+  static StatusOr<std::unique_ptr<ShapedBuffer>> MakeShapedBuffer(
+      const Shape& shape, const perftools::gputools::Platform* platform,
+      int device_ordinal);
+
+  // Convenience method which creates a ShapedBuffer of array shape (not a
+  // tuple). Its single buffer pointer is set to the given value "buffer". The
+  // given buffer must be large enough to store the given shape as given by
+  // ShapeUtil::ByteSizeOf.
+  static StatusOr<std::unique_ptr<ShapedBuffer>> MakeArrayShapedBuffer(
+      const Shape& shape, const perftools::gputools::Platform* platform,
+      int device_ordinal, const perftools::gputools::DeviceMemoryBase& buffer);
+
+  // Convenience method which creates a ShapedBuffer of a non-nested tuple. The
+  // buffer pointers in the returned ShapedBuffer are set to the given
+  // "buffers". The number of buffers must match the number of elements in the
+  // tuple shape, and each buffer must be large enough to store its respective
+  // element shape as given by ShapeUtil::ByteSizeOf.
+  static StatusOr<std::unique_ptr<ShapedBuffer>> MakeUnnestedTupleShapedBuffer(
+      const Shape& shape, const perftools::gputools::Platform* platform,
+      int device_ordinal,
+      const tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
+          buffers);
+
+  const Shape& shape() const { return shape_; }
+  const perftools::gputools::Platform* platform() const { return platform_; }
+  int device_ordinal() const { return device_ordinal_; }
+
+  // Returns the buffer at the given shape index, where index is defined as in
+  // ShapeUtil::GetSubshape.
+  const perftools::gputools::DeviceMemoryBase& buffer(
+      const ShapeIndex& index) const;
+  perftools::gputools::DeviceMemoryBase* mutable_buffer(
+      const ShapeIndex& index);
+
+  // Returns the underlying structure which stores the buffer pointers.
+  const std::vector<perftools::gputools::DeviceMemoryBase>& buffers() const {
+    return buffers_;
+  }
+  std::vector<perftools::gputools::DeviceMemoryBase>* mutable_buffers() {
+    return &buffers_;
+  }
+
+  // Returns the tree of indices which map to buffer pointers.
+  const ShapeTree<size_t>& shape_index_to_buffer_entry() const {
+    return shape_index_to_buffer_entry_;
+  }
+  ShapeTree<size_t>* mutable_shape_index_to_buffer_entry() {
+    return &shape_index_to_buffer_entry_;
+  }
+
+ protected:
+  ShapedBuffer(const Shape& shape,
+               const perftools::gputools::Platform* platform,
+               int device_ordinal);
+
+  // The shape of the device buffer with layout.
+  const Shape shape_;
+
+  // The list of DeviceMemoryBase pointers representing this shape.
+  // Note that there can be a many-to-one relationship between tuple elements
+  // and buffers. To account for this, shape_index_to_buffer_entry_ allows us
+  // to map from a position in a shape to an index into this list.
+ std::vector buffers_; + + // The tree of indices into buffers_. + ShapeTree shape_index_to_buffer_entry_; + + // The platform the memory is allocated on. + const perftools::gputools::Platform* platform_; + + // The device the memory is allocated on. + const int device_ordinal_; +}; + +// ShapedBuffer derived class which allocates all internal buffers on +// construction and deallocates the memory when the object is +// destructed. +class ScopedShapedBuffer : public ShapedBuffer { + public: + // Return a new ScopedShapedBuffer of an arbitrary shape. All buffers in the + // ScopedShapedBuffers are automatically allocated to exactly the size of + // their respective array shape. + static StatusOr> MakeScopedShapedBuffer( + const Shape& shape, DeviceMemoryAllocator* allocator, int device_ordinal); + + // All buffers in the shape are deallocated on destruction. + ~ScopedShapedBuffer(); + + protected: + ScopedShapedBuffer(const Shape& shape, DeviceMemoryAllocator* allocator, + int device_ordinal); + ScopedShapedBuffer(const ScopedShapedBuffer&) = delete; + void operator=(const ScopedShapedBuffer&) = delete; + + DeviceMemoryAllocator* allocator_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_SHAPED_BUFFER_H_ diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc new file mode 100644 index 0000000000..c7f6a13023 --- /dev/null +++ b/tensorflow/compiler/xla/service/transfer_manager.cc @@ -0,0 +1,143 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/transfer_manager.h" + +#include +#include + +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" + +namespace se = ::perftools::gputools; + +namespace xla { + +/* static */ tensorflow::mutex* +TransferManager::platform_transfer_manager_mutex() { + static tensorflow::mutex* m = new tensorflow::mutex; + return m; +} + +/* static */ std::map* +TransferManager::GetPlatformTransferManagers() { + static auto* r = + new std::map; + return r; +} + +/* static */ void TransferManager::RegisterTransferManager( + se::Platform::Id platform_id, + TransferManagerCreationFunction creation_function) { + tensorflow::mutex_lock lock( + *TransferManager::platform_transfer_manager_mutex()); + auto* managers = GetPlatformTransferManagers(); + CHECK(managers->find(platform_id) == managers->end()); + (*managers)[platform_id].creation_function = creation_function; +} + +/* static */ StatusOr TransferManager::GetForPlatform( + const se::Platform* platform) { + tensorflow::mutex_lock lock( + *TransferManager::platform_transfer_manager_mutex()); + auto* managers = GetPlatformTransferManagers(); + + auto it = managers->find(platform->id()); + if (it == managers->end()) { + return NotFound( + "could not find registered transfer manager for platform %s -- check " + "target linkage", + platform->Name().c_str()); + } + + if (it->second.manager == nullptr) { + // Lazily create the transfer manager the first time it is needed + it->second.manager = (*it->second.creation_function)(); + } + + return it->second.manager; +} + +Status TransferManager::TransferBufferFromDevice( + se::StreamExecutor* executor, const se::DeviceMemoryBase& source, + int64 size, void* destination) { + if (source.size() < size) { + return FailedPrecondition( + "Source allocation on device not large enough for data tranfer: " + "%lld < %lld", + source.size(), size); + } + auto copy_status = executor->SynchronousMemcpyD2H(source, size, destination); + if (!copy_status.ok()) { + return AddStatus( + Status(static_cast(copy_status.code()), + copy_status.error_message()), + "failed transfer from device to buffer"); + } + return Status::OK(); +} + +Status TransferManager::TransferBufferToDevice( + se::StreamExecutor* executor, int64 size, const void* source, + se::DeviceMemoryBase* destination) { + if (destination->size() < size) { + return FailedPrecondition( + "Destination allocation on device not large enough for data tranfer: " + "%lld < %lld", + destination->size(), size); + } + auto copy_status = executor->SynchronousMemcpyH2D(source, size, destination); + if (!copy_status.ok()) { + return AddStatus( + Status(static_cast(copy_status.code()), + copy_status.error_message()), + "failed transfer of buffer to device"); + } + return Status::OK(); +} + +StatusOr> +TransferManager::GatherBufferPointersFromTuple( + se::StreamExecutor* executor, const se::DeviceMemoryBase& source, + const Shape& shape) { + TF_RET_CHECK(ShapeUtil::IsTuple(shape)); + + std::set buffer_pointers; + buffer_pointers.insert(source); + + TF_ASSIGN_OR_RETURN(std::vector tuple_elements, + ShallowCopyTupleFromDevice(executor, source, shape)); + for (auto i = 0; i < tuple_elements.size(); ++i) { + const Shape& element_shape = 
shape.tuple_shapes(i); + if (ShapeUtil::IsTuple(element_shape)) { + TF_ASSIGN_OR_RETURN( + std::set buffer_pointers_in_element, + GatherBufferPointersFromTuple(executor, tuple_elements[i], + element_shape)); + buffer_pointers.insert(buffer_pointers_in_element.begin(), + buffer_pointers_in_element.end()); + } else { + buffer_pointers.insert(tuple_elements[i]); + } + } + return std::move(buffer_pointers); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h new file mode 100644 index 0000000000..90dc921b7d --- /dev/null +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -0,0 +1,151 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_TRANSFER_MANAGER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_TRANSFER_MANAGER_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// The TransferManager interface lets backends provide platform-specific +// mechanisms for constructing literals from given device memory handles. +// This lets each platform customize how literals are transferred to/from the +// device in terms of padding, leading dimension, etc. +class TransferManager { + public: + virtual ~TransferManager() {} + + // Returns the ID of the platform that this transfer manager acts on. + virtual perftools::gputools::Platform::Id PlatformId() const = 0; + + // Transfers the region into the provided literal using the provided + // executor. device_shape is the shape, including layout, of the data on the + // device, while literal_shape will be the shape for the literal. device_shape + // and literal_shape must be compatible, but need not have the same layout. + virtual Status TransferLiteralFromDevice( + perftools::gputools::StreamExecutor* executor, + const perftools::gputools::DeviceMemoryBase& region, + const Shape& device_shape, const Shape& literal_shape, + Literal* literal) = 0; + + // Transfers the given literal into the provided region output parameter, + // using the given executor. + virtual Status TransferLiteralToDevice( + perftools::gputools::StreamExecutor* executor, const Literal& literal, + perftools::gputools::DeviceMemoryBase* region) = 0; + + // Transfers the given literal into the Infeed interface of the device, + // using the given executor. + virtual Status TransferLiteralToInfeed( + perftools::gputools::StreamExecutor* executor, + const Literal& literal) = 0; + + // Resets the device that the given executor runs on. 
+ virtual Status ResetDevice(perftools::gputools::StreamExecutor* executor) = 0; + + // Shallow copy a tuple from the device and create a DeviceMemoryBase object + // for each element in the tuple. A DeviceMemoryBase object refers to the + // buffer containing the data of that element. The DeviceMemoryBase objects + // are returned as a vector. + virtual StatusOr> + ShallowCopyTupleFromDevice( + perftools::gputools::StreamExecutor* executor, + const perftools::gputools::DeviceMemoryBase& source, + const Shape& shape) = 0; + + // Returns all buffer pointers that the tuple `source` refers to. Unlike + // ShallowCopyTupleFromDevice, this function gather buffer pointers in nested + // tuples as well. Also, the returned DeviceMemoryBase objects are + // deduplicated. + StatusOr> + GatherBufferPointersFromTuple( + perftools::gputools::StreamExecutor* executor, + const perftools::gputools::DeviceMemoryBase& source, const Shape& shape); + + // Determines the byte size requirement for the given shape on the underlying + // architecture. This will be used to allocate an appropriately sized memory + // region for a host-to-device transfer. + virtual int64 GetByteSizeRequirement(const Shape& shape) = 0; + + // Transfer a memory block of the given size from the device source into the + // 'destination' buffer. + // + // size is the size to transfer to destination in bytes. + virtual Status TransferBufferFromDevice( + perftools::gputools::StreamExecutor* executor, + const perftools::gputools::DeviceMemoryBase& source, int64 size, + void* destination); + + // Transfer a memory block of the given size from 'source' buffer to the given + // destination of the device. + // + // size is the size to transfer from source in bytes. + virtual Status TransferBufferToDevice( + perftools::gputools::StreamExecutor* executor, int64 size, + const void* source, perftools::gputools::DeviceMemoryBase* destination); + + typedef TransferManager* (*TransferManagerCreationFunction)(); + + ///// + // The TransferManager class also serves as a point to register objects for + // the various platforms. + + // Registers the TransferManager singleton for the platform kind. This is + // assumed to be a singleton, so no ownership is transferred. + // + // Precondition: a platform kind must not be registered more than once. + static void RegisterTransferManager( + perftools::gputools::Platform::Id platform_id, + TransferManagerCreationFunction transfer_manager); + + // Returns the transfer manager singleton pointer if it is available for the + // given platform, or an error status if it is not. + static StatusOr GetForPlatform( + const perftools::gputools::Platform* platform); + + private: + // Routine that returns the mutex that guards the + // platform-to-transfer manager map. Done as a routine to + // ensure correct initialization ordering, since RegisterTransferManager + // can be called during program initialization time. + static tensorflow::mutex* platform_transfer_manager_mutex(); + + // State kept for each kind of TransferManager. Registration functions + // set up creation_function, and then we use that to lazily create + // "manager" the first time GetForPlatform is invoked for a particular id. + struct State { + TransferManager* manager = nullptr; + TransferManagerCreationFunction creation_function = nullptr; + }; + + // Map from platform kind to transfer manager singleton. 
+ static std::map* + GetPlatformTransferManagers(); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_TRANSFER_MANAGER_H_ diff --git a/tensorflow/compiler/xla/service/transfer_manager_test.cc b/tensorflow/compiler/xla/service/transfer_manager_test.cc new file mode 100644 index 0000000000..564111c4f2 --- /dev/null +++ b/tensorflow/compiler/xla/service/transfer_manager_test.cc @@ -0,0 +1,159 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/generic_transfer_manager.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/types.h" + +namespace se = ::perftools::gputools; + +namespace xla { + +namespace { + +class CpuTransferManagerTest : public ::testing::Test { + protected: + CpuTransferManagerTest() : transfer_manager_(se::host::kHostPlatformId) { + se::Platform* platform = + se::MultiPlatformManager::PlatformWithId(se::host::kHostPlatformId) + .ValueOrDie(); + stream_exec_ = + platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0)) + .ValueOrDie(); + } + + ~CpuTransferManagerTest() override {} + + se::StreamExecutor* stream_exec_; + GenericTransferManager transfer_manager_; +}; + +TEST_F(CpuTransferManagerTest, TransferR0U32ToDevice) { + std::vector storage(sizeof(uint32), '\x00'); + se::DeviceMemoryBase memptr(storage.data(), storage.size()); + std::unique_ptr literal = LiteralUtil::CreateR0(42); + TF_CHECK_OK(transfer_manager_.TransferLiteralToDevice(stream_exec_, *literal, + &memptr)); + + CHECK_EQ(42, *reinterpret_cast(&storage[0])); +} + +TEST_F(CpuTransferManagerTest, TransferR1F32ToDevice) { + std::vector storage(4 * sizeof(float), '\x00'); + se::DeviceMemoryBase memptr(storage.data(), storage.size()); + std::unique_ptr literal = + LiteralUtil::CreateR1({1.25f, 2.5f, -17.0f, -20.125f}); + TF_CHECK_OK(transfer_manager_.TransferLiteralToDevice(stream_exec_, *literal, + &memptr)); + + CHECK_EQ(1.25f, *reinterpret_cast(&storage[0])); + CHECK_EQ(2.5f, *reinterpret_cast(&storage[sizeof(float)])); + CHECK_EQ(-17.0f, *reinterpret_cast(&storage[2 * sizeof(float)])); + CHECK_EQ(-20.125f, *reinterpret_cast(&storage[3 * sizeof(float)])); +} + +TEST_F(CpuTransferManagerTest, TransferR1U8ToDevice) { + std::vector storage(16, '\x00'); + se::DeviceMemoryBase memptr(storage.data(), storage.size()); + const char* str = "0123456789abcdef"; + std::unique_ptr literal = LiteralUtil::CreateR1U8(str); + TF_CHECK_OK(transfer_manager_.TransferLiteralToDevice(stream_exec_, *literal, + &memptr)); 
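+  // The transfer copies the literal's bytes directly into the host memory
+  // backing `memptr`, so spot-check the first, middle, and last characters.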
+ + CHECK_EQ('0', storage[0]); + CHECK_EQ('8', storage[8]); + CHECK_EQ('f', storage[15]); +} + +TEST_F(CpuTransferManagerTest, TransferR0U32FromDevice) { + std::vector storage(1, 42); + se::DeviceMemoryBase memptr(storage.data(), + storage.size() * sizeof(storage[0])); + Literal literal; + const Shape shape = ShapeUtil::MakeShape(U32, {}); + TF_CHECK_OK(transfer_manager_.TransferLiteralFromDevice( + stream_exec_, memptr, shape, shape, &literal)); + + LiteralTestUtil::ExpectR0Equal(42, literal); +} + +TEST_F(CpuTransferManagerTest, TransferR1F32FromDevice) { + std::vector storage{1.25f, 2.5f, -17.0f, -20.125f}; + se::DeviceMemoryBase memptr(storage.data(), + storage.size() * sizeof(storage[0])); + Literal literal; + const Shape shape = ShapeUtil::MakeShape(F32, {4}); + TF_CHECK_OK(transfer_manager_.TransferLiteralFromDevice( + stream_exec_, memptr, shape, shape, &literal)); + + LiteralTestUtil::ExpectR1Equal({1.25, 2.5, -17.0, -20.125}, literal); +} + +TEST_F(CpuTransferManagerTest, TransferR1U8FromDevice) { + std::vector storage{'k', 'l', 'm', 'n'}; + se::DeviceMemoryBase memptr(storage.data(), + storage.size() * sizeof(storage[0])); + Literal literal; + const Shape shape = ShapeUtil::MakeShape(U8, {4}); + TF_CHECK_OK(transfer_manager_.TransferLiteralFromDevice( + stream_exec_, memptr, shape, shape, &literal)); + CHECK_EQ("klmn", literal.u8s()); +} + +TEST_F(CpuTransferManagerTest, TransferBufferFromDevice) { + std::vector storage{1, 5, 42}; + int64 size = storage.size() * sizeof(storage[0]); + se::DeviceMemoryBase memptr(storage.data(), size); + + std::vector dest(3, 0); + TF_CHECK_OK(transfer_manager_.TransferBufferFromDevice(stream_exec_, memptr, + size, dest.data())); + ASSERT_EQ(1, dest[0]); + ASSERT_EQ(5, dest[1]); + ASSERT_EQ(42, dest[2]); +} + +TEST_F(CpuTransferManagerTest, TransferBufferToDevice) { + int64 size = 3 * sizeof(uint64); + std::vector storage(size, 0); + se::DeviceMemoryBase memptr(storage.data(), size); + + std::vector dest{1, 5, 42}; + TF_CHECK_OK(transfer_manager_.TransferBufferToDevice(stream_exec_, size, + dest.data(), &memptr)); + std::vector* storage64 = + reinterpret_cast*>(&storage); + ASSERT_EQ(1, (*storage64)[0]); + ASSERT_EQ(5, (*storage64)[1]); + ASSERT_EQ(42, (*storage64)[2]); +} + +// TODO(b/24679870): add similar tests for GPUs + +} // namespace + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc new file mode 100644 index 0000000000..fb4ff1e68e --- /dev/null +++ b/tensorflow/compiler/xla/service/transpose_folding.cc @@ -0,0 +1,109 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/transpose_folding.h" + +#include + +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +namespace { + +bool IsOperandFoldableToDot(const HloInstruction& hlo) { + return hlo.IsRank2Transpose() && + hlo.user_count() == 1; // The dot is its only user. +} + +bool CanFoldOperandsIntoDot( + const HloInstruction& dot, + const TransposeFolding::IsTransposableGemmFn& is_transposable_gemm) { + if (HloOpcode::kDot != dot.opcode()) { + return false; + } + + if (!is_transposable_gemm(dot)) { + return false; + } + + const HloInstruction* lhs = dot.operand(0); + const HloInstruction* rhs = dot.operand(1); + bool lhs_foldable = IsOperandFoldableToDot(*lhs); + bool rhs_foldable = IsOperandFoldableToDot(*rhs); + if (!lhs_foldable && !rhs_foldable) { + return false; + } + return true; +} + +// Folds the operands of `dot` that are foldable transposes. `computation` is +// the parent HLO computation of `dot`. `module` is the parent HloModule of +// `computation`. +// +// Returns whether the module is changed. +bool FoldTransposeIntoDot(HloInstruction* dot, HloComputation* computation) { + std::vector instructions_to_fuse(1, dot); + for (HloInstruction* operand : dot->operands()) { + if (IsOperandFoldableToDot(*operand)) { + instructions_to_fuse.push_back(operand); + } + } + + // Early-exit if no operands are foldable. + if (instructions_to_fuse.size() == 1) { + return false; + } + + computation->CreateFusionInstruction( + instructions_to_fuse, HloInstruction::FusionKind::kTransposeDot); + return true; +} + +} // namespace + +TransposeFolding::TransposeFolding(IsTransposableGemmFn is_transposable_gemm) + : HloPass("transpose-folding"), + is_transposable_gemm_(std::move(is_transposable_gemm)) {} + +StatusOr TransposeFolding::Run(HloModule* module) { + // Modifying the graph while traversing is dangerous, so we find all folding + // opportunities before actually folding them. + HloComputation* entry_computation = module->entry_computation(); + + std::vector foldable_dots; + auto visit_fn = [this, &foldable_dots](HloInstruction* instruction) { + if (CanFoldOperandsIntoDot(*instruction, is_transposable_gemm_)) { + foldable_dots.emplace_back(instruction); + } + return tensorflow::Status::OK(); + }; + TF_RETURN_IF_ERROR(entry_computation->root_instruction()->Accept(visit_fn)); + + bool changed = false; + for (HloInstruction* dot : foldable_dots) { + changed |= FoldTransposeIntoDot(dot, entry_computation); + } + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/transpose_folding.h b/tensorflow/compiler/xla/service/transpose_folding.h new file mode 100644 index 0000000000..7bec2f2364 --- /dev/null +++ b/tensorflow/compiler/xla/service/transpose_folding.h @@ -0,0 +1,41 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_TRANSPOSE_FOLDING_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_TRANSPOSE_FOLDING_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass.h" + +namespace xla { + +// HLO pass that folds transpose operators into Dot operators, where the Dot +// operator is implemented by a GEMM kernel that can transpose its inputs. +class TransposeFolding : public HloPass { + public: + // IsTransposableGemmFn should return true iff the instruction argument is + // implemented as a GEMM kernel that supports transposing its arguments. + typedef std::function IsTransposableGemmFn; + explicit TransposeFolding(IsTransposableGemmFn is_transposable_gemm); + + StatusOr Run(HloModule* module) override; + + private: + IsTransposableGemmFn is_transposable_gemm_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_TRANSPOSE_FOLDING_H_ diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc new file mode 100644 index 0000000000..09f932e29e --- /dev/null +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -0,0 +1,149 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/transpose_folding.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +class TransposeFoldingTest : public ::testing::Test { + protected: + void FoldTranspose(HloModule* module) { + TransposeFolding transpose_folding(gpu::ImplementedAsGemm); + EXPECT_IS_OK(transpose_folding.Run(module).status()); + } +}; + +TEST_F(TransposeFoldingTest, FoldTranspose) { + auto builder = HloComputation::Builder("entry_computation"); + HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {2, 3}), + /*name=*/"x")); + HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {2, 3}), + /*name=*/"y")); + HloInstruction* transpose_y = + builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {3, 2}), y, {1, 0})); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {2, 2}), /*opcode=*/HloOpcode::kDot, + /*lhs=*/x, /*rhs=*/transpose_y)); + + HloModule module("test_module"); + HloComputation* entry_computation = + module.AddEntryComputation(builder.Build(dot)); + FoldTranspose(&module); + + // Instructions after folding: x, y, and the fusion. + std::set instruction_set; + for (auto& instruction : entry_computation->instructions()) { + instruction_set.insert(instruction.get()); + } + CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation."; + CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation."; + CHECK_EQ(1, instruction_set.size()) + << "entry_computation should contain exactly 3 instructions."; + HloInstruction* fusion = *instruction_set.begin(); + EXPECT_EQ(HloOpcode::kFusion, fusion->opcode()); + + // The fusion instruction should contain two parameters, one transpose and + // one dot. 
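+  // (2 parameters + 1 transpose + 1 dot = 4 fused instructions.)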
+ EXPECT_EQ(4, fusion->fused_instructions().size()); +} + +TEST_F(TransposeFoldingTest, FoldTransposeConstant) { + auto builder = HloComputation::Builder("entry_computation"); + // 2x1 + HloInstruction* const0 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR2({{1}, {2}}))); + // 3x2 + HloInstruction* const1 = + builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}))); + HloInstruction* transpose0 = + builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {1, 2}), const0, {1, 0})); + HloInstruction* transpose1 = + builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {2, 3}), const1, {1, 0})); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {1, 3}), /*opcode=*/HloOpcode::kDot, + /*lhs=*/transpose0, /*rhs=*/transpose1)); + + HloModule module("test_module"); + HloComputation* entry_computation = + module.AddEntryComputation(builder.Build(dot)); + FoldTranspose(&module); + + for (auto& instruction : entry_computation->instructions()) { + if (instruction->opcode() == HloOpcode::kFusion) { + CHECK_EQ(2, instruction->operand_count()); + EXPECT_EQ(const0, instruction->operand(0)); + EXPECT_EQ(const1, instruction->operand(1)); + } + } + + // The created fusion instruction should contain two parameters, two + // transposes (one for each parameter) and one dot. + EXPECT_EQ(5, + entry_computation->root_instruction()->fused_instructions().size()); +} + +TEST_F(TransposeFoldingTest, FuseWithConstantOperands) { + auto builder = HloComputation::Builder("entry"); + // (1.0 + 2.0) * (2.0 - 3.0) + HloInstruction* const1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction* const2 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction* const3 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); + HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary( + const1->shape(), HloOpcode::kAdd, const1, const2)); + HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary( + const2->shape(), HloOpcode::kSubtract, const2, const3)); + HloInstruction* mul = builder.AddInstruction(HloInstruction::CreateBinary( + add->shape(), HloOpcode::kMultiply, add, sub)); + + HloModule module("fuse_with_constant_operands"); + HloComputation* entry_computation = + module.AddEntryComputation(builder.Build(mul)); + HloInstruction* call = module.OutlineExpressionFromComputation( + {add, sub, mul}, "", entry_computation); + EXPECT_EQ(call, entry_computation->root_instruction()); + HloComputation* callee_computation = call->to_apply(); + // The arguments to the call should be const1, const2, and const3. + EXPECT_MATCH(call->operands(), testing::UnorderedMatcher( + const1, const2, const3)); + + // The callee should contain 3 parameters and 3 binary operators. + EXPECT_EQ(6, callee_computation->instructions().size()); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc new file mode 100644 index 0000000000..0e0c0b02e3 --- /dev/null +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -0,0 +1,495 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +string BufferAlias::ToString() const { + return tensorflow::strings::StrCat("BufferAlias(", instruction_->name(), "[", + tensorflow::str_util::Join(index_, ","), + "] => ", buffer_->ToString(), ")"); +} + +std::ostream& operator<<(std::ostream& out, const BufferAlias& buffer_alias) { + out << buffer_alias.ToString(); + return out; +} + +bool PointsToSet::IsAmbiguous() const { + bool ambiguous = false; + TF_CHECK_OK(ForEachElement( + [&ambiguous](const ShapeIndex& /*index*/, bool /*is_leaf*/, + const std::vector& points_to) { + ambiguous |= points_to.size() > 1; + return Status::OK(); + })); + return ambiguous; +} + +bool PointsToSet::IsDistinct() const { + bool distinct = true; + std::set all_points_to; + TF_CHECK_OK(ForEachElement([&distinct, &all_points_to]( + const ShapeIndex& /*index*/, bool /*is_leaf*/, + const std::vector& points_to) { + for (auto& buffer : points_to) { + if (all_points_to.count(buffer) != 0) { + distinct = false; + } + all_points_to.insert(buffer); + } + return Status::OK(); + })); + return distinct; +} + +size_t PointsToSet::size() const { + // Because pointed-to elements may be duplicated we have to create a flattened + // set and return the size. 
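+  // (The same LogicalBuffer may appear at several indices, e.g. a tuple with
+  // a repeated element points to that element's buffer at each such index.)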
+ return CreateFlattenedSet().size(); +} + +std::set PointsToSet::CreateFlattenedSet() const { + std::set flat_set; + TF_CHECK_OK(ForEachElement( + [&flat_set](const ShapeIndex& /*index*/, bool /*is_leaf*/, + const std::vector& buffers) { + flat_set.insert(buffers.begin(), buffers.end()); + return Status::OK(); + })); + return flat_set; +} + +bool PointsToSet::ContainsBuffer(const LogicalBuffer& buffer) const { + bool found = false; + TF_CHECK_OK(ForEachElement([&found, &buffer]( + const ShapeIndex& /*index*/, bool /*is_leaf*/, + const std::vector& pointed_to_buffers) { + if (!found && + std::find(pointed_to_buffers.begin(), pointed_to_buffers.end(), + &buffer) != pointed_to_buffers.end()) { + found = true; + } + return Status::OK(); + })); + return found; +} + +bool PointsToSet::ContainsBufferAtIndex(const LogicalBuffer& buffer, + const ShapeIndex& index) const { + const std::vector& pointed_to_buffers = element(index); + return std::find(pointed_to_buffers.begin(), pointed_to_buffers.end(), + &buffer) != pointed_to_buffers.end(); +} + +void PointsToSet::AddPointedToBuffer(const LogicalBuffer& buffer, + const ShapeIndex& index) { + if (ContainsBufferAtIndex(buffer, index)) { + return; + } + mutable_element(index)->push_back(&buffer); +} + +const std::set& PointsToSet::tuple_sources( + const ShapeIndex& index) const { + return tuple_sources_.element(index); +} + +void PointsToSet::add_tuple_source(const ShapeIndex& index, + HloInstruction* tuple) { + tuple_sources_.mutable_element(index)->insert(tuple); +} + +/* static */ StatusOr> +TuplePointsToAnalysis::Run(const HloModule* module) { + std::unique_ptr analysis( + new TuplePointsToAnalysis(module)); + TF_RETURN_IF_ERROR(analysis->Analyze()); + return std::move(analysis); +} + +Status TuplePointsToAnalysis::Analyze() { + points_to_.clear(); + for (auto& computation : module_->computations()) { + TF_RETURN_IF_ERROR(computation->Accept(this)); + for (auto& instruction : computation->instructions()) { + TF_RETURN_IF_ERROR(GatherBuffersDefinedByInstruction( + instruction.get(), &instruction_defined_buffers_[instruction.get()])); + + const PointsToSet& points_to_set = GetPointsToSet(instruction.get()); + TF_RETURN_IF_ERROR(points_to_set.ForEachElement([this, &instruction]( + const ShapeIndex& index, bool /*is_leaf*/, + const std::vector& pointed_to_buffers) { + for (const LogicalBuffer* buffer : pointed_to_buffers) { + if (buffer_aliases_.count(buffer) == 0) { + buffer_aliases_.insert({buffer, std::vector()}); + } + buffer_aliases_[buffer].emplace_back(*buffer, instruction.get(), + index); + } + return Status::OK(); + })); + } + } + + XLA_VLOG_LINES(3, ToString()); + + return Status::OK(); +} + +const LogicalBuffer& TuplePointsToAnalysis::NewLogicalBuffer( + HloInstruction* instruction, const ShapeIndex& index) { + CHECK_EQ(logical_buffers_.size(), next_buffer_id_); + logical_buffers_.push_back( + MakeUnique(instruction, index, next_buffer_id_)); + ++next_buffer_id_; + return *logical_buffers_.back(); +} + +Status TuplePointsToAnalysis::DefaultAction(HloInstruction* hlo_instruction) { + // Create trivial points-to set for instruction. Each points-to set at index i + // contains a single element LogicalBuffer(hlo_instruction, i). This indicates + // that this instruction is the source of all buffers in its own output. 
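+  // For example, an array-shaped kAdd gets the single entry
+  // LogicalBuffer(add, {}); a tuple-shaped instruction handled here gets a
+  // fresh LogicalBuffer at {} and at every element index.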
+ PointsToSet& points_to_set = CreateEmptyPointsToSet(hlo_instruction); + TF_RETURN_IF_ERROR(points_to_set.ForEachMutableElement( + [this, hlo_instruction](const ShapeIndex& index, bool /*is_leaf*/, + std::vector* buffers) { + const LogicalBuffer& buffer = NewLogicalBuffer(hlo_instruction, index); + buffers->push_back(&buffer); + return Status::OK(); + })); + + if (ShapeUtil::IsTuple(hlo_instruction->shape())) { + // If the hlo instruction is a tuple-shaped, then trivially the instruction + // itself is the source of the tuple. + points_to_set.add_tuple_source({}, hlo_instruction); + } + + return Status::OK(); +} + +Status TuplePointsToAnalysis::HandleGetTupleElement( + HloInstruction* get_tuple_element, HloInstruction* operand) { + // GetTupleElement forwards a pointer to a particular element of the tuple + // operand. + int64 element_index = get_tuple_element->tuple_index(); + + PointsToSet& points_to_set = CreateEmptyPointsToSet(get_tuple_element); + const PointsToSet& operand_points_to_set = *FindOrDie(points_to_, operand); + + // Copy the points-to set (and tuple sources) at index {element_index} of the + // operand to the points-to set for this GetTupleElement instruction. + TF_RETURN_IF_ERROR(points_to_set.ForEachMutableElement([&, this]( + const ShapeIndex& target_index, bool /*is_leaf*/, + std::vector* points_to) { + // Construct an index into the operand by prepending element_index to the + // index for the GetTupleElement instruction's points-to set. + ShapeIndex src_index; + src_index.push_back(element_index); + for (auto element : target_index) { + src_index.push_back(element); + } + + *points_to = operand_points_to_set.element(src_index); + for (HloInstruction* tuple : + operand_points_to_set.tuple_sources(src_index)) { + points_to_set.add_tuple_source(target_index, tuple); + } + return Status::OK(); + })); + + return Status::OK(); +} + +Status TuplePointsToAnalysis::HandleCopy(HloInstruction* copy, + HloInstruction* operand) { + // A kCopy instruction performs a shallow copy of the operand. The top-level + // buffer (index={}) is newly created, but all other buffers (in the case of a + // tuple shape) come from the operand + PointsToSet& points_to_set = CreateCopiedPointsToSet(copy, operand); + points_to_set.mutable_element(/*index=*/{})->clear(); + points_to_set.AddPointedToBuffer(NewLogicalBuffer(copy, /*index=*/{}), + /*index=*/{}); + + return Status::OK(); +} + +Status TuplePointsToAnalysis::HandleBitcast(HloInstruction* bitcast) { + // A kBitcast instruction aliases its operand. That is, the buffer of its + // result *is* the buffer of its operand, so just copy the operands points-to + // set. + CreateCopiedPointsToSet(bitcast, bitcast->operand(0)); + return Status::OK(); +} + +Status TuplePointsToAnalysis::HandleTuple( + HloInstruction* tuple, + tensorflow::gtl::ArraySlice operands) { + PointsToSet& points_to_set = CreateEmptyPointsToSet(tuple); + points_to_set.AddPointedToBuffer(NewLogicalBuffer(tuple, /*index=*/{}), + /*index=*/{}); + + // A tuple contains references to all input operands and transitively any + // references in those operands. + for (int64 i = 0; i < operands.size(); ++i) { + const PointsToSet& operand_points_to_set = + *FindOrDie(points_to_, operands[i]); + + // Copy the points-to set (and tuple sources) of the operand into the + // respective subtree of the tuple instructions points-to set. 
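+    // The target index is the source index prefixed with the operand's
+    // position i within the tuple.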
+ TF_RETURN_IF_ERROR(operand_points_to_set.ForEachElement( + [&points_to_set, &operand_points_to_set, i]( + const ShapeIndex& src_index, bool /*is_leaf*/, + const std::vector& points_to) { + ShapeIndex target_index; + target_index.push_back(i); + for (auto element : src_index) { + target_index.push_back(element); + } + + *points_to_set.mutable_element(target_index) = points_to; + + for (HloInstruction* tuple : + operand_points_to_set.tuple_sources(src_index)) { + points_to_set.add_tuple_source(target_index, tuple); + } + return Status::OK(); + })); + } + + points_to_set.add_tuple_source({}, tuple); + + return Status::OK(); +} + +Status TuplePointsToAnalysis::HandleSelect(HloInstruction* select, + HloInstruction* /*pred*/, + HloInstruction* on_true, + HloInstruction* on_false) { + // Select allocates a new buffer and then shallow copies the on_true or + // on_false buffer into this new buffer. Which side is chosen cannot be + // determined statically so conservatively set the points-to set to the union + // of these on_true and on_false operands. + // + // First create a copy of the on_true points-to set (and tuple sources), then + // add in elements of the on_false points-to set (tuple sources). + PointsToSet& points_to_set = CreateCopiedPointsToSet(select, on_true); + const PointsToSet& false_points_to_set = *FindOrDie(points_to_, on_false); + TF_RETURN_IF_ERROR(points_to_set.ForEachMutableElement( + [&](const ShapeIndex& index, bool /*is_leaf*/, + std::vector* buffers) { + for (const LogicalBuffer* false_buffer : + false_points_to_set.element(index)) { + points_to_set.AddPointedToBuffer(*false_buffer, index); + } + + for (HloInstruction* tuple : false_points_to_set.tuple_sources(index)) { + points_to_set.add_tuple_source(index, tuple); + } + return Status::OK(); + })); + + // Select creates a new (top-level) buffer to store its result, so its + // respective element in the points-to set should contain only itself. + points_to_set.mutable_element({})->clear(); + points_to_set.AddPointedToBuffer(NewLogicalBuffer(select, /*index=*/{}), + /*index=*/{}); + return Status::OK(); +} + +Status TuplePointsToAnalysis::HandleFusion(HloInstruction* fusion) { + return ShapeUtil::IsTuple(fusion->shape()) + ? 
Unimplemented("HandleFusion with tuple output") + : DefaultAction(fusion); +} + +const PointsToSet& TuplePointsToAnalysis::GetPointsToSet( + const HloInstruction* hlo_instruction) const { + return *FindOrDie(points_to_, hlo_instruction); +} + +PointsToSet& TuplePointsToAnalysis::CreateEmptyPointsToSet( + const HloInstruction* instruction) { + CHECK_EQ(0, points_to_.count(instruction)); + points_to_[instruction] = MakeUnique(instruction->shape()); + return *FindOrDie(points_to_, instruction); +} + +bool TuplePointsToAnalysis::InstructionDefinesBufferAtIndex( + HloInstruction* instruction, const ShapeIndex& index) const { + const std::vector& buffers = + GetPointsToSet(instruction).element(index); + return (buffers.size() == 1 && buffers[0]->instruction() == instruction); +} + +Status TuplePointsToAnalysis::VerifyBuffer(const LogicalBuffer& buffer) const { + if (!InstructionDefinesBufferAtIndex(buffer.instruction(), buffer.index())) { + return FailedPrecondition( + "LogicalBuffer %s is ill-defined: instruction %s does not define a " + "buffer at that index", + buffer.ToString().c_str(), buffer.instruction()->name().c_str()); + } + + if (buffer.id() < 0 || buffer.id() >= next_buffer_id_) { + return FailedPrecondition( + "LogicalBuffer %s is ill-defined: invalid id %lld", + buffer.ToString().c_str(), buffer.id()); + } + if (GetBuffer(buffer.id()).instruction() != buffer.instruction() || + GetBuffer(buffer.id()).index() != buffer.index()) { + return FailedPrecondition( + "LogicalBuffer %s is ill-defined: buffer with same id differs: %s", + buffer.ToString().c_str(), GetBuffer(buffer.id()).ToString().c_str()); + } + + return Status::OK(); +} + +const LogicalBuffer& TuplePointsToAnalysis::GetBuffer( + LogicalBuffer::Id id) const { + CHECK_GE(id, 0); + CHECK_LT(id, logical_buffers_.size()); + return *logical_buffers_[id]; +} + +StatusOr TuplePointsToAnalysis::GetBufferDefinedAt( + const HloInstruction* instruction, const ShapeIndex& index) const { + const std::vector& buffers = + GetPointsToSet(instruction).element(index); + if (buffers.size() != 1 || buffers[0]->instruction() != instruction) { + return FailedPrecondition( + "instruction %s does not define buffer at index {%s}", + instruction->name().c_str(), + tensorflow::str_util::Join(index, ",").c_str()); + } + return buffers[0]; +} + +const std::vector& TuplePointsToAnalysis::GetBufferAliases( + const LogicalBuffer& buffer) const { + return buffer_aliases_.at(&buffer); +} + +const std::vector& +TuplePointsToAnalysis::GetBuffersDefinedByInstruction( + const HloInstruction* instruction) const { + return instruction_defined_buffers_.at(instruction); +} + +Status TuplePointsToAnalysis::GatherBuffersDefinedByInstruction( + const HloInstruction* instruction, + std::vector* buffers) { + return GetPointsToSet(instruction) + .ForEachElement([this, buffers, instruction]( + const ShapeIndex& index, bool /*is_leaf*/, + const std::vector& source_buffers) { + // Add buffers which 'instruction' is the source of. + CHECK(!source_buffers.empty()); + if (source_buffers.size() == 1 && + source_buffers[0]->instruction() == instruction) { + // If this instruction is the source of this buffer the + // indices must match. + DCHECK(source_buffers[0]->index() == index); + buffers->push_back(source_buffers[0]); + } else { + // If the points-to set includes more than one buffer then + // necessarily this instruction did not produce the + // buffer. 
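+            // (For example, a kSelect's element buffers come from its on_true
+            // and on_false operands rather than from the select itself.)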
+ for (const LogicalBuffer* source_buffer : source_buffers) { + DCHECK(source_buffer->instruction() != instruction); + } + } + return Status::OK(); + }); +} + +PointsToSet& TuplePointsToAnalysis::CreateCopiedPointsToSet( + const HloInstruction* instruction, const HloInstruction* src) { + // PointsToSet doesn't have a copy constructor so copy over element-by-element + // from src PointsToSet. + PointsToSet& dst_points_to_set = CreateEmptyPointsToSet(instruction); + const PointsToSet& src_points_to_set = GetPointsToSet(src); + TF_CHECK_OK(dst_points_to_set.ForEachMutableElement( + [this, &dst_points_to_set, &src_points_to_set]( + const ShapeIndex& index, bool /*is_leaf*/, + std::vector* buffers) { + *buffers = src_points_to_set.element(index); + for (auto& tuple_source : src_points_to_set.tuple_sources(index)) { + dst_points_to_set.add_tuple_source(index, tuple_source); + } + return Status::OK(); + })); + return *FindOrDie(points_to_, instruction); +} + +string TuplePointsToAnalysis::ToString() const { + string output = tensorflow::strings::Printf( + "TuplePointsToSet for module %s:\n", module_->name().c_str()); + for (auto& computation : module_->computations()) { + tensorflow::strings::StrAppend(&output, "computation ", + computation->name().c_str(), ":\n"); + for (const HloInstruction* instruction : + computation->MakeInstructionPostOrder()) { + tensorflow::strings::StrAppend(&output, " instruction ", + instruction->ToShortString(), ":\n"); + const PointsToSet& points_to_set = GetPointsToSet(instruction); + TF_CHECK_OK(points_to_set.ForEachElement( + [&output](const ShapeIndex& index, bool /*is_leaf*/, + const std::vector& points_to) { + tensorflow::strings::StrAppend( + &output, " {", tensorflow::str_util::Join(index, ","), "}: ", + tensorflow::str_util::Join( + points_to, ", ", + [](string* out, const LogicalBuffer* source) { + out->append(source->ToString()); + }), + "\n"); + return Status::OK(); + })); + } + for (auto& buffer : logical_buffers_) { + tensorflow::strings::StrAppend(&output, " buffer ", buffer->ToString(), + ":\n"); + for (const BufferAlias& buffer_alias : buffer_aliases_.at(buffer.get())) { + tensorflow::strings::StrAppend(&output, " alias ", + buffer_alias.ToString(), "\n"); + } + } + } + + tensorflow::strings::StrAppend(&output, "LogicalBuffers:\n"); + for (const auto& buffer : logical_buffers_) { + tensorflow::strings::StrAppend(&output, " ", buffer->ToString()); + } + return output; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h new file mode 100644 index 0000000000..7a3eb772d6 --- /dev/null +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h @@ -0,0 +1,268 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_TUPLE_POINTS_TO_ANALYSIS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_TUPLE_POINTS_TO_ANALYSIS_H_ + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/logical_buffer.h" +#include "tensorflow/compiler/xla/shape_tree.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// A class describing the source(s) of the Buffer(s) contained in the output of +// a particular HLO instruction. The structure of PointsToSet mirrors the +// structure of the instruction's shape which may be an arbitrary tree (eg, a +// nested tuple). Each node in this tree corresponds to a single buffer in the +// instruction's output and contains the set of Buffers which might define +// the corresponding buffer. +class PointsToSet : public ShapeTree> { + public: + explicit PointsToSet(const Shape& shape) + : ShapeTree>(shape), + tuple_sources_(shape) {} + + // Returns true if any points-to sets for any subshape element is not a + // singleton. + bool IsAmbiguous() const; + + // Returns true if no LogicalBuffer appears in more than one points-to set of + // the shape nodes. + bool IsDistinct() const; + + // Returns the total number of different LogicalBuffers contained in this + // object. This is equal to CreateFlattenedSet().size(). + size_t size() const; + + // Creates a set containing the union of all LogicalBuffers contained in the + // PointsToSet. + std::set CreateFlattenedSet() const; + + // Returns true if the given buffer is in the points-to set at the given + // index. + bool ContainsBufferAtIndex(const LogicalBuffer& buffer, + const ShapeIndex& index) const; + + // Returns true if the given buffer is in the points-to set at any index. + bool ContainsBuffer(const LogicalBuffer& buffer) const; + + // Adds the given buffer to the points-to set at the given index. This is a + // nop if the buffer already is in the set at that index. + void AddPointedToBuffer(const LogicalBuffer& buffer, const ShapeIndex& index); + + // For the subshape at the given index (where index is defined as in + // ShapeUtil::GetSubshape) this method returns the set of HLO instructions + // which may produce the tuple subshape at that index. For example, given: + // + // %tuple1 = tuple(...) + // %tuple2 = tuple(...) + // %select = select(%tuple1, %tuple2) + // %nested_tuple = tuple(%select, %tuple1) + // + // These are the values for tuple_sources() for the PointsToSet of + // %nested_tuple: + // + // tuple_sources({}) = {%nested_tuple} + // tuple_sources({0}) = {%tuple1, %tuple2} + // tuple_sources({1}) = {%tuple1} + // + // tuple_sources() at the index of an array shape (not a tuple) returns the + // empty set. The instructions in the set returned by tuple_sources + // necessarily are either Tuple instructions, constants, or parameters. 
+ const std::set& tuple_sources(const ShapeIndex& index) const; + + // Add a tuple source instruction for the given index. + void add_tuple_source(const ShapeIndex& index, HloInstruction* tuple); + + private: + ShapeTree> tuple_sources_; + + // PointsToSet contains references (const LogicalBuffer*) to elements within + // TuplePointsToAnalysis so disable copying. + TF_DISALLOW_COPY_AND_ASSIGN(PointsToSet); +}; + +// This class describes a particular subshape in a computation (instruction and +// shape index) and the logical buffer which may be a source of the subshape +// value. +class BufferAlias { + public: + BufferAlias(const LogicalBuffer& buffer, HloInstruction* instruction, + const ShapeIndex& index) + : buffer_(&buffer), instruction_(instruction), index_(index) {} + + // Return the logical buffer aliased at the instruction and index. + const LogicalBuffer& buffer() const { return *buffer_; } + + // Return the instruction/index of the subshape. + HloInstruction* instruction() const { return instruction_; } + const ShapeIndex& index() const { return index_; } + + bool operator==(const BufferAlias& other) const { + return buffer_ == other.buffer_ && instruction_ == other.instruction_ && + index_ == other.index_; + } + bool operator!=(const BufferAlias& other) const { return !(*this == other); } + + string ToString() const; + + private: + const LogicalBuffer* buffer_; + HloInstruction* instruction_; + const ShapeIndex index_; +}; + +std::ostream& operator<<(std::ostream& out, const BufferAlias& buffer_alias); + +// DFS visitor that performs tuple points-to analysis. This analysis determines +// the potential sources of each buffer in each instruction's output. +class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { + public: + static StatusOr> Run( + const HloModule* module); + + // Return the points-to set of an instruction. This describes the potential + // sources of each buffer in the instruction's output. + const PointsToSet& GetPointsToSet( + const HloInstruction* hlo_instruction) const; + + // Returns the logical buffer with the given ID. + const LogicalBuffer& GetBuffer(LogicalBuffer::Id id) const; + + // Returns the buffer defined at the given instruction and index. An error is + // returned if no buffer is defined at that point. + StatusOr GetBufferDefinedAt( + const HloInstruction* instruction, const ShapeIndex& index) const; + + // Return a vector containing all BufferAliases of the given logical buffer + // This trivially includes the BufferAlias with same instruction and index as + // the logical buffer itself, so the returned vector is never empty. The + // buffer alias set is the inverse of the points-to set. That is, + // LogicalBuffer B is in the points-to set of instruction I at index N iff + // instruction I, index N is a BufferAlias of B. + const std::vector& GetBufferAliases( + const LogicalBuffer& buffer) const; + + // Return a vector containing all logical buffers in the module. + const std::vector>& logical_buffers() const { + return logical_buffers_; + } + + // Returns a vector of buffers that the instruction produces. Most + // instructions produce a single buffer (the top-level buffer), some produce + // no buffers (eg bitcast), and some produce more than one buffer (eg, + // tuple-shaped parameters). + const std::vector& GetBuffersDefinedByInstruction( + const HloInstruction* instruction) const; + + // Returns true if the given instruction defines a buffer at the given index. 
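+  // For example, a Tuple instruction defines only its top-level buffer at {};
+  // the buffers at {0}, {1}, ... are defined by its operands.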
+  bool InstructionDefinesBufferAtIndex(HloInstruction* instruction,
+                                       const ShapeIndex& index) const;
+
+  // Returns an OK status if the given buffer is defined by instruction
+  // 'buffer.instruction()' at index 'buffer.index()' and if the given buffer
+  // matches the TuplePointsToAnalysis' LogicalBuffer with 'buffer.id'. Returns
+  // a FailedPrecondition error status otherwise. An example of a LogicalBuffer
+  // which is not defined is a tuple element in a Tuple instruction. In this
+  // case, the Tuple instruction does not define the LogicalBuffer; rather,
+  // that index aliases one of its operands.
+  Status VerifyBuffer(const LogicalBuffer& buffer) const;
+
+  Status DefaultAction(HloInstruction* hlo_instruction) override;
+  Status HandleTuple(
+      HloInstruction* tuple,
+      tensorflow::gtl::ArraySlice<HloInstruction*> operands) override;
+  Status HandleGetTupleElement(HloInstruction* get_tuple_element,
+                               HloInstruction* operand) override;
+  Status HandleBitcast(HloInstruction* bitcast) override;
+  Status HandleCopy(HloInstruction* copy, HloInstruction* operand) override;
+  Status HandleFusion(HloInstruction* fusion) override;
+  Status HandleSelect(HloInstruction* select, HloInstruction* pred,
+                      HloInstruction* on_true,
+                      HloInstruction* on_false) override;
+
+  string ToString() const;
+
+ private:
+  explicit TuplePointsToAnalysis(const HloModule* module) : module_(module) {}
+
+  // Perform the analysis. Should be called immediately after constructing the
+  // object and before calling GetPointsToSet.
+  Status Analyze();
+
+  // Create a new logical buffer and return a reference to it. The newly
+  // created buffer is stored in an internal vector of LogicalBuffers and can
+  // be accessed with GetBuffer.
+  const LogicalBuffer& NewLogicalBuffer(HloInstruction* instruction,
+                                        const ShapeIndex& index);
+
+  // Creates an empty PointsToSet in the points_to_ map for the given
+  // instruction.
+  PointsToSet& CreateEmptyPointsToSet(const HloInstruction* instruction);
+
+  // Creates a PointsToSet in the points_to_ map for 'instruction' which is a
+  // copy of the existing PointsToSet for 'src'.
+  PointsToSet& CreateCopiedPointsToSet(const HloInstruction* instruction,
+                                       const HloInstruction* src);
+
+  // Adds the buffers defined by the given instruction to the given vector.
+  Status GatherBuffersDefinedByInstruction(
+      const HloInstruction* instruction,
+      std::vector<const LogicalBuffer*>* buffers);
+
+  // The module this analysis is performed on.
+  const HloModule* module_;
+
+  // A map containing a PointsToSet for every HLO instruction.
+  tensorflow::gtl::FlatMap<const HloInstruction*, std::unique_ptr<PointsToSet>>
+      points_to_;
+
+  // A map containing the LogicalBuffers defined by each HLO instruction.
+  std::unordered_map<const HloInstruction*, std::vector<const LogicalBuffer*>>
+      instruction_defined_buffers_;
+
+  // A map from each LogicalBuffer to its BufferAliases.
+  std::unordered_map<const LogicalBuffer*, std::vector<BufferAlias>>
+      buffer_aliases_;
+
+  // All logical buffers in the module, indexed by LogicalBuffer::Id. Keep as
+  // a vector of std::unique_ptr to keep the underlying pointer values stable.
+  std::vector<std::unique_ptr<LogicalBuffer>> logical_buffers_;
+
+  // The ID of the next logical buffer created.
+  LogicalBuffer::Id next_buffer_id_ = 0;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(TuplePointsToAnalysis);
+};
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_TUPLE_POINTS_TO_ANALYSIS_H_
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
new file mode 100644
index 0000000000..e4dd4d309e
--- /dev/null
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
@@ -0,0 +1,544 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" + +#include +#include + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +class TuplePointsToAnalysisTest : public HloTestBase { + protected: + // Builds a module with the given entry computation and runs points to + // analysis. + void BuildModuleAndRunAnalysis(std::unique_ptr computation) { + module_.reset(new HloModule(TestName())); + module_->AddEntryComputation(std::move(computation)); + points_to_analysis_ = + TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); + } + + // Returns the LogicalBuffer defined at the given instruction and + // index. CHECKs if no buffer is defined at that point. + const LogicalBuffer* const GetBuffer(const HloInstruction* instruction, + const ShapeIndex& index) { + const std::vector& pointed_to = + points_to_analysis_->GetPointsToSet(instruction).element(index); + CHECK_EQ(1, pointed_to.size()); + CHECK_EQ(instruction, pointed_to[0]->instruction()); + CHECK(index == pointed_to[0]->index()); + return pointed_to[0]; + } + + // Checks that the given points-to set contains exactly (unordered) the given + // LogicalBuffers. + void ExpectHasBuffers( + const std::vector& points_to_set, + tensorflow::gtl::ArraySlice buffers) { + std::vector vec(buffers.begin(), buffers.end()); + EXPECT_MATCH(points_to_set, testing::UnorderedElementsAre(vec)); + } + + // Checks that the given points-to set contains exactly (unordered) the + // top-level buffers of the given instructions. + void ExpectHasTopLevelBuffers( + const std::vector& points_to_set, + tensorflow::gtl::ArraySlice instructions) { + std::vector buffers; + for (auto instruction : instructions) { + buffers.push_back(GetBuffer(instruction, /*index=*/{})); + } + ExpectHasBuffers(points_to_set, buffers); + } + + // Overload which takes a std::set instead of a std::vector. + void ExpectHasTopLevelBuffers( + const std::set& points_to_set, + tensorflow::gtl::ArraySlice instructions) { + ExpectHasTopLevelBuffers(std::vector( + points_to_set.begin(), points_to_set.end()), + instructions); + } + + // Checks that the buffer defined at the given instruction and index has + // aliases which are exactly (unordered) the given instruction/index pairs. 
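+  // For example, for a constant %c placed at element {0} of a tuple %t, the
+  // buffer defined at %c is expected to have aliases {%c, {}} and {%t, {0}}
+  // (see the BufferAliases test below for the nested-tuple case).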
+ void ExpectHasBufferAliases( + const HloInstruction* instruction, const ShapeIndex& index, + tensorflow::gtl::ArraySlice> + expected) { + const LogicalBuffer* buffer = + points_to_analysis_->GetBufferDefinedAt(instruction, index) + .ValueOrDie(); + std::vector expected_aliases; + for (auto& pair : expected) { + expected_aliases.push_back(BufferAlias(*buffer, pair.first, pair.second)); + } + EXPECT_MATCH(points_to_analysis_->GetBufferAliases(*buffer), + testing::UnorderedElementsAre(expected_aliases)); + } + + std::unique_ptr module_; + std::unique_ptr points_to_analysis_; +}; + +// Expect the given std::set as A contains exactly the given +// HloInstruction*s as __VA_ARGS__. +#define EXPECT_ISET(A, ...) \ + EXPECT_MATCH(testing::SetToVec(A), \ + testing::UnorderedMatcher(__VA_ARGS__)) + +TEST_F(TuplePointsToAnalysisTest, SimpleTuple) { + auto builder = HloComputation::Builder(TestName()); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + auto tuple = builder.AddInstruction( + HloInstruction::CreateTuple({constant1, constant2})); + + BuildModuleAndRunAnalysis(builder.Build()); + EXPECT_EQ(1, points_to_analysis_->GetPointsToSet(constant1).size()); + ExpectHasTopLevelBuffers( + points_to_analysis_->GetPointsToSet(constant1).element({}), {constant1}); + EXPECT_TRUE( + points_to_analysis_->GetPointsToSet(constant1).tuple_sources({}).empty()); + EXPECT_TRUE(points_to_analysis_->GetPointsToSet(tuple).IsDistinct()); + + EXPECT_EQ(1, points_to_analysis_->GetPointsToSet(constant2).size()); + ExpectHasTopLevelBuffers( + points_to_analysis_->GetPointsToSet(constant2).element({}), {constant2}); + EXPECT_TRUE( + points_to_analysis_->GetPointsToSet(constant2).tuple_sources({}).empty()); + + EXPECT_EQ(3, points_to_analysis_->GetPointsToSet(tuple).size()); + EXPECT_FALSE(points_to_analysis_->GetPointsToSet(tuple).IsAmbiguous()); + EXPECT_ISET(points_to_analysis_->GetPointsToSet(tuple).tuple_sources({}), + tuple); + + ExpectHasTopLevelBuffers( + points_to_analysis_->GetPointsToSet(tuple).CreateFlattenedSet(), + {constant1, constant2, tuple}); + ExpectHasTopLevelBuffers( + points_to_analysis_->GetPointsToSet(tuple).element({}), {tuple}); + ExpectHasTopLevelBuffers( + points_to_analysis_->GetPointsToSet(tuple).element({0}), {constant1}); + ExpectHasTopLevelBuffers( + points_to_analysis_->GetPointsToSet(tuple).element({1}), {constant2}); + + const PointsToSet& tuple_points_to_set = + points_to_analysis_->GetPointsToSet(tuple); + EXPECT_TRUE(tuple_points_to_set.ContainsBufferAtIndex( + *GetBuffer(constant1, {}), {0})); + EXPECT_TRUE(tuple_points_to_set.ContainsBufferAtIndex( + *GetBuffer(constant2, {}), {1})); + EXPECT_FALSE(tuple_points_to_set.ContainsBufferAtIndex( + *GetBuffer(constant2, {}), {0})); + EXPECT_TRUE(tuple_points_to_set.ContainsBuffer(*GetBuffer(constant1, {}))); + EXPECT_TRUE(tuple_points_to_set.ContainsBuffer(*GetBuffer(constant2, {}))); +} + +TEST_F(TuplePointsToAnalysisTest, NestedTuple) { + // Create a (nested) tuple containing an inner tuple. The points-to set of the + // outer tuple should contain all elements of the points-to set of the inner + // tuple. 
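+  // Concretely, the inner tuple's flattened points-to set below should be
+  // {constant1, constant2, inner_tuple} (size 3) and the outer tuple's should
+  // be {constant1, constant2, constant3, inner_tuple, tuple} (size 5).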
+ auto builder = HloComputation::Builder(TestName()); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + auto inner_tuple = builder.AddInstruction( + HloInstruction::CreateTuple({constant1, constant2})); + + auto constant3 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); + auto tuple = builder.AddInstruction( + HloInstruction::CreateTuple({inner_tuple, constant3})); + + BuildModuleAndRunAnalysis(builder.Build()); + ExpectHasTopLevelBuffers( + points_to_analysis_->GetPointsToSet(constant1).element({}), {constant1}); + ExpectHasTopLevelBuffers( + points_to_analysis_->GetPointsToSet(constant2).element({}), {constant2}); + ExpectHasTopLevelBuffers( + points_to_analysis_->GetPointsToSet(constant3).element({}), {constant3}); + + EXPECT_EQ(3, points_to_analysis_->GetPointsToSet(inner_tuple).size()); + EXPECT_FALSE(points_to_analysis_->GetPointsToSet(inner_tuple).IsAmbiguous()); + EXPECT_TRUE(points_to_analysis_->GetPointsToSet(inner_tuple).IsDistinct()); + ExpectHasTopLevelBuffers( + points_to_analysis_->GetPointsToSet(inner_tuple).CreateFlattenedSet(), + {constant1, constant2, inner_tuple}); + ExpectHasTopLevelBuffers( + points_to_analysis_->GetPointsToSet(inner_tuple).element({}), + {inner_tuple}); + EXPECT_ISET( + points_to_analysis_->GetPointsToSet(inner_tuple).tuple_sources({}), + inner_tuple); + + EXPECT_EQ(5, points_to_analysis_->GetPointsToSet(tuple).size()); + EXPECT_FALSE(points_to_analysis_->GetPointsToSet(tuple).IsAmbiguous()); + ExpectHasTopLevelBuffers( + points_to_analysis_->GetPointsToSet(tuple).CreateFlattenedSet(), + {constant1, constant2, constant3, inner_tuple, tuple}); + + EXPECT_ISET(points_to_analysis_->GetPointsToSet(tuple).tuple_sources({}), + tuple); + EXPECT_ISET(points_to_analysis_->GetPointsToSet(tuple).tuple_sources({0}), + inner_tuple); + EXPECT_TRUE( + points_to_analysis_->GetPointsToSet(tuple).tuple_sources({1}).empty()); + + ExpectHasTopLevelBuffers( + points_to_analysis_->GetPointsToSet(tuple).element({0}), {inner_tuple}); + ExpectHasTopLevelBuffers( + points_to_analysis_->GetPointsToSet(tuple).element({0, 0}), {constant1}); + ExpectHasTopLevelBuffers( + points_to_analysis_->GetPointsToSet(tuple).element({0, 1}), {constant2}); + ExpectHasTopLevelBuffers( + points_to_analysis_->GetPointsToSet(tuple).element({1}), {constant3}); +} + +TEST_F(TuplePointsToAnalysisTest, GetTupleElement) { + // Create a nested tuple, then extract the inner tuple with GetTupleElement. + // The points-to set of the GetTupleElement should be the same as the inner + // tuple. 
+ auto builder = HloComputation::Builder(TestName()); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + auto inner_tuple = builder.AddInstruction( + HloInstruction::CreateTuple({constant1, constant2})); + + auto constant3 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); + auto tuple = builder.AddInstruction( + HloInstruction::CreateTuple({inner_tuple, constant3})); + + auto get_tuple_element = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(inner_tuple->shape(), tuple, 0)); + + BuildModuleAndRunAnalysis(builder.Build()); + + auto& points_to_set = points_to_analysis_->GetPointsToSet(get_tuple_element); + EXPECT_EQ(3, points_to_set.size()); + EXPECT_FALSE(points_to_set.IsAmbiguous()); + EXPECT_TRUE(points_to_set.IsDistinct()); + ExpectHasTopLevelBuffers(points_to_set.CreateFlattenedSet(), + {constant1, constant2, inner_tuple}); + ExpectHasTopLevelBuffers(points_to_set.element({}), {inner_tuple}); + + EXPECT_ISET(points_to_set.tuple_sources({}), inner_tuple); +} + +TEST_F(TuplePointsToAnalysisTest, DuplicatedElement) { + // Create a tuple which contains duplicate elements. + auto builder = HloComputation::Builder(TestName()); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + auto tuple = builder.AddInstruction( + HloInstruction::CreateTuple({constant, constant, constant})); + + BuildModuleAndRunAnalysis(builder.Build()); + + EXPECT_EQ(2, points_to_analysis_->GetPointsToSet(tuple).size()); + EXPECT_FALSE(points_to_analysis_->GetPointsToSet(tuple).IsAmbiguous()); + EXPECT_FALSE(points_to_analysis_->GetPointsToSet(tuple).IsDistinct()); + ExpectHasTopLevelBuffers( + points_to_analysis_->GetPointsToSet(tuple).element({}), {tuple}); + ExpectHasTopLevelBuffers( + points_to_analysis_->GetPointsToSet(tuple).CreateFlattenedSet(), + {constant, tuple}); +} + +TEST_F(TuplePointsToAnalysisTest, TupleCopy) { + // Create a copy (HloOpcode::kCopy) of a tuple. The points to sets should be + // the same. + auto builder = HloComputation::Builder(TestName()); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + auto tuple = builder.AddInstruction( + HloInstruction::CreateTuple({constant1, constant2})); + auto copy = builder.AddInstruction( + HloInstruction::CreateUnary(tuple->shape(), HloOpcode::kCopy, tuple)); + + BuildModuleAndRunAnalysis(builder.Build()); + + EXPECT_FALSE(points_to_analysis_->GetPointsToSet(copy).IsAmbiguous()); + EXPECT_TRUE(points_to_analysis_->GetPointsToSet(copy).IsDistinct()); + ExpectHasTopLevelBuffers( + points_to_analysis_->GetPointsToSet(tuple).CreateFlattenedSet(), + {constant1, constant2, tuple}); + ExpectHasTopLevelBuffers( + points_to_analysis_->GetPointsToSet(copy).element({}), {copy}); + ExpectHasTopLevelBuffers( + points_to_analysis_->GetPointsToSet(copy).CreateFlattenedSet(), + {constant1, constant2, copy}); +} + +TEST_F(TuplePointsToAnalysisTest, TupleSelect) { + // Select from two different tuples. This should create an ambiguous points to + // set containing the union of both sides. 
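+  // Concretely, element {0} of the select below should point to
+  // {constant1, constant2} (the union of tuple1's and tuple2's element 0),
+  // which is what makes the resulting points-to set ambiguous.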
+ auto builder = HloComputation::Builder(TestName()); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + auto tuple1 = builder.AddInstruction( + HloInstruction::CreateTuple({constant1, constant2})); + auto tuple2 = builder.AddInstruction( + HloInstruction::CreateTuple({constant2, constant2})); + + auto pred = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + auto select = builder.AddInstruction(HloInstruction::CreateTernary( + tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); + + BuildModuleAndRunAnalysis(builder.Build()); + + auto& points_to_set = points_to_analysis_->GetPointsToSet(select); + EXPECT_EQ(3, points_to_set.size()); + EXPECT_TRUE(points_to_set.IsAmbiguous()); + EXPECT_FALSE(points_to_set.IsDistinct()); + ExpectHasTopLevelBuffers(points_to_set.element({}), {select}); + ExpectHasTopLevelBuffers(points_to_set.element({0}), {constant1, constant2}); + ExpectHasTopLevelBuffers(points_to_set.element({1}), {constant2}); + ExpectHasTopLevelBuffers(points_to_set.CreateFlattenedSet(), + {constant1, constant2, select}); +} + +TEST_F(TuplePointsToAnalysisTest, SelectTupleParameters) { + // Create a Select which selects between two tuple parameters. Verify the + // points-to sets and tuple sources are properly set. + Shape tuple_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {1, 2, 3}), ShapeUtil::MakeShape(U32, {5})}); + + auto builder = HloComputation::Builder(TestName()); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "param0")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, tuple_shape, "param1")); + auto pred = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + auto select = builder.AddInstruction(HloInstruction::CreateTernary( + tuple_shape, HloOpcode::kSelect, pred, param0, param1)); + auto copy = builder.AddInstruction( + HloInstruction::CreateUnary(tuple_shape, HloOpcode::kCopy, select)); + + BuildModuleAndRunAnalysis(builder.Build()); + + // The points-to set of each element of a tuple parameters should be itself + // with the appropriate index. + ExpectHasBuffers(points_to_analysis_->GetPointsToSet(param0).element({}), + {GetBuffer(param0, {})}); + ExpectHasBuffers(points_to_analysis_->GetPointsToSet(param0).element({0}), + {GetBuffer(param0, {0})}); + ExpectHasBuffers(points_to_analysis_->GetPointsToSet(param0).element({1}), + {GetBuffer(param0, {1})}); + + // Select's point-to set of its subelements should be the respective + // subelements of param0 and param1. The top-level buffer, however, does not + // alias as it is created by the select instruction. + ExpectHasBuffers(points_to_analysis_->GetPointsToSet(select).element({}), + {GetBuffer(select, {})}); + ExpectHasBuffers(points_to_analysis_->GetPointsToSet(select).element({0}), + {GetBuffer(param0, {0}), GetBuffer(param1, {0})}); + ExpectHasBuffers(points_to_analysis_->GetPointsToSet(select).element({1}), + {GetBuffer(param0, {1}), GetBuffer(param1, {1})}); + + // Copy should be identical to select other than the top-level buffer. 
+ ExpectHasBuffers(points_to_analysis_->GetPointsToSet(copy).element({}), + {GetBuffer(copy, {})}); + ExpectHasBuffers(points_to_analysis_->GetPointsToSet(copy).element({0}), + {GetBuffer(param0, {0}), GetBuffer(param1, {0})}); + ExpectHasBuffers(points_to_analysis_->GetPointsToSet(copy).element({1}), + {GetBuffer(param0, {1}), GetBuffer(param1, {1})}); +} + +TEST_F(TuplePointsToAnalysisTest, UnambiguousTupleSelect) { + // Select from two identical tuples. The result should not be ambiguous. + auto builder = HloComputation::Builder(TestName()); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + auto tuple1 = builder.AddInstruction( + HloInstruction::CreateTuple({constant1, constant2})); + auto tuple2 = builder.AddInstruction( + HloInstruction::CreateTuple({constant1, constant2})); + + auto pred = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + auto select = builder.AddInstruction(HloInstruction::CreateTernary( + tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); + + BuildModuleAndRunAnalysis(builder.Build()); + + auto& points_to_set = points_to_analysis_->GetPointsToSet(select); + EXPECT_EQ(3, points_to_set.size()); + EXPECT_FALSE(points_to_set.IsAmbiguous()); + EXPECT_TRUE(points_to_set.IsDistinct()); + ExpectHasTopLevelBuffers(points_to_set.element({}), {select}); + ExpectHasTopLevelBuffers(points_to_set.element({0}), {constant1}); + ExpectHasTopLevelBuffers(points_to_set.element({1}), {constant2}); + ExpectHasTopLevelBuffers(points_to_set.CreateFlattenedSet(), + {constant1, constant2, select}); +} + +TEST_F(TuplePointsToAnalysisTest, NestedTupleSelect) { + // Select from nested tuples. Verify that the nested points-to sets contain + // the right values. + auto builder = HloComputation::Builder(TestName()); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + auto inner_tuple1 = builder.AddInstruction( + HloInstruction::CreateTuple({constant1, constant2})); + auto inner_tuple2 = builder.AddInstruction( + HloInstruction::CreateTuple({constant2, constant2})); + + auto tuple1 = + builder.AddInstruction(HloInstruction::CreateTuple({inner_tuple1})); + auto tuple2 = + builder.AddInstruction(HloInstruction::CreateTuple({inner_tuple2})); + + auto pred = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + auto select = builder.AddInstruction(HloInstruction::CreateTernary( + tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); + + BuildModuleAndRunAnalysis(builder.Build()); + + auto& points_to_set = points_to_analysis_->GetPointsToSet(select); + EXPECT_EQ(5, points_to_set.size()); + EXPECT_TRUE(points_to_set.IsAmbiguous()); + EXPECT_FALSE(points_to_set.IsDistinct()); + + // Verify points-to set. + ExpectHasTopLevelBuffers(points_to_set.element({}), {select}); + ExpectHasTopLevelBuffers(points_to_set.element({0}), + {inner_tuple1, inner_tuple2}); + ExpectHasTopLevelBuffers(points_to_set.element({0, 0}), + {constant1, constant2}); + ExpectHasTopLevelBuffers(points_to_set.element({0, 1}), {constant2}); + + // Verify tuple sources. 
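+  // The top-level tuple sources are {tuple1, tuple2}, since either side of the
+  // select may supply the tuple value; likewise index {0} has sources
+  // {inner_tuple1, inner_tuple2}.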
+ EXPECT_ISET(points_to_set.tuple_sources({}), tuple1, tuple2); + EXPECT_ISET(points_to_set.tuple_sources({0}), inner_tuple1, inner_tuple2); + EXPECT_EQ(0, points_to_set.tuple_sources({0, 0}).size()); + EXPECT_EQ(0, points_to_set.tuple_sources({0, 1}).size()); +} + +TEST_F(TuplePointsToAnalysisTest, TupleWithBitcast) { + // Bitcast is an alias of its operand. A tuple with a bitcast element should + // have the operand of the bitcast in its points-to set. + auto builder = HloComputation::Builder(TestName()); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + auto bitcast = builder.AddInstruction(HloInstruction::CreateUnary( + constant2->shape(), HloOpcode::kBitcast, constant2)); + auto tuple = + builder.AddInstruction(HloInstruction::CreateTuple({constant1, bitcast})); + + BuildModuleAndRunAnalysis(builder.Build()); + + EXPECT_EQ(1, points_to_analysis_->GetPointsToSet(bitcast).size()); + ExpectHasTopLevelBuffers( + points_to_analysis_->GetPointsToSet(bitcast).element({}), {constant2}); + EXPECT_TRUE( + points_to_analysis_->GetPointsToSet(bitcast).tuple_sources({}).empty()); + + EXPECT_EQ(3, points_to_analysis_->GetPointsToSet(tuple).size()); + EXPECT_FALSE(points_to_analysis_->GetPointsToSet(tuple).IsAmbiguous()); + EXPECT_ISET(points_to_analysis_->GetPointsToSet(tuple).tuple_sources({}), + tuple); + + ExpectHasTopLevelBuffers( + points_to_analysis_->GetPointsToSet(tuple).CreateFlattenedSet(), + {constant1, constant2, tuple}); + ExpectHasTopLevelBuffers( + points_to_analysis_->GetPointsToSet(tuple).element({}), {tuple}); + ExpectHasTopLevelBuffers( + points_to_analysis_->GetPointsToSet(tuple).element({0}), {constant1}); + ExpectHasTopLevelBuffers( + points_to_analysis_->GetPointsToSet(tuple).element({1}), {constant2}); +} + +TEST_F(TuplePointsToAnalysisTest, PointsToTupleConstantElements) { + // Construct a tuple constant and kCopy it. Verify the points-to set of the + // copy correctly correctly points into the nested elements of the constant. + auto builder = HloComputation::Builder(TestName()); + auto tuple_constant = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::MakeTuple( + {LiteralUtil::CreateR2({{1.0}, {2.0}}).get(), + LiteralUtil::CreateR1({2.0, 42}).get()}))); + auto copy = builder.AddInstruction(HloInstruction::CreateUnary( + tuple_constant->shape(), HloOpcode::kCopy, tuple_constant)); + + BuildModuleAndRunAnalysis(builder.Build()); + + auto& points_to_set = points_to_analysis_->GetPointsToSet(copy); + + ExpectHasBuffers(points_to_set.element({}), {GetBuffer(copy, {})}); + ExpectHasBuffers(points_to_set.element({0}), + {GetBuffer(tuple_constant, {0})}); + ExpectHasBuffers(points_to_set.element({1}), + {GetBuffer(tuple_constant, {1})}); +} + +TEST_F(TuplePointsToAnalysisTest, BufferAliases) { + // Create a nested tuple in which individual elements appear multiple + // times. Verify buffer alias sets. 
+ auto builder = HloComputation::Builder(TestName()); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + auto inner_tuple = builder.AddInstruction( + HloInstruction::CreateTuple({constant1, constant2})); + auto tuple = builder.AddInstruction( + HloInstruction::CreateTuple({inner_tuple, constant2})); + + BuildModuleAndRunAnalysis(builder.Build()); + + ExpectHasBufferAliases( + constant1, /*index=*/{}, + {{constant1, {}}, {inner_tuple, {0}}, {tuple, {0, 0}}}); + ExpectHasBufferAliases( + constant2, /*index=*/{}, + {{constant2, {}}, {inner_tuple, {1}}, {tuple, {0, 1}}, {tuple, {1}}}); + ExpectHasBufferAliases(inner_tuple, /*index=*/{}, + {{inner_tuple, {}}, {tuple, {0}}}); + ExpectHasBufferAliases(tuple, /*index=*/{}, {{tuple, {}}}); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc new file mode 100644 index 0000000000..04029c7b01 --- /dev/null +++ b/tensorflow/compiler/xla/service/user_computation.cc @@ -0,0 +1,2117 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/user_computation.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace xla { +namespace { + +HloOpcode UnaryOperationToHloOpcode(UnaryOperation unop) { + switch (unop) { + case UNOP_ABS: + return HloOpcode::kAbs; + case UNOP_CEIL: + return HloOpcode::kCeil; + case UNOP_EXP: + return HloOpcode::kExp; + case UNOP_FLOOR: + return HloOpcode::kFloor; + case UNOP_LOG: + return HloOpcode::kLog; + case UNOP_LOGICAL_NOT: + return HloOpcode::kLogicalNot; + case UNOP_NEGATE: + return HloOpcode::kNegate; + case UNOP_SIGN: + return HloOpcode::kSign; + case UNOP_SORT: + return HloOpcode::kSort; + case UNOP_TANH: + return HloOpcode::kTanh; + default: + LOG(FATAL) << "unhandled operation " << unop; + } +} + +HloOpcode BinaryOperationToHloOpcode(BinaryOperation binop) { + switch (binop) { + case BINOP_DOT: + return HloOpcode::kDot; + case BINOP_MUL: + return HloOpcode::kMultiply; + case BINOP_ADD: + return HloOpcode::kAdd; + case BINOP_SUB: + return HloOpcode::kSubtract; + case BINOP_INDEX: + return HloOpcode::kIndex; + case BINOP_DIV: + return HloOpcode::kDivide; + case BINOP_EQ: + return HloOpcode::kEq; + case BINOP_GE: + return HloOpcode::kGe; + case BINOP_GT: + return HloOpcode::kGt; + case BINOP_LE: + return HloOpcode::kLe; + case BINOP_LT: + return HloOpcode::kLt; + case BINOP_NE: + return HloOpcode::kNe; + case BINOP_MAX: + return HloOpcode::kMaximum; + case BINOP_MIN: + return HloOpcode::kMinimum; + case BINOP_POW: + return HloOpcode::kPower; + case BINOP_REM: + return HloOpcode::kRemainder; + case BINOP_LOGICAL_OR: + return HloOpcode::kLogicalOr; + case BINOP_LOGICAL_AND: + return HloOpcode::kLogicalAnd; + default: + LOG(FATAL) << "unhandled operation " << binop; + } +} + +HloOpcode TernaryOperationToHloOpcode(TernaryOperation triop) { + switch (triop) { + case TRIOP_CLAMP: + return HloOpcode::kClamp; + case TRIOP_SELECT: + return HloOpcode::kSelect; + case TRIOP_UPDATE: + return HloOpcode::kUpdate; + default: + LOG(FATAL) << "unhandled operation " << triop; + } +} + +HloOpcode VariadicOperationToHloOpcode(VariadicOperation varop) { + switch (varop) { + case VAROP_TUPLE: + return HloOpcode::kTuple; + default: + LOG(FATAL) << "unhandled operation " << varop; + } +} + +} // namespace + +/* static */ StatusOr> +UserComputation::MakeWithRemapping( + const SessionComputation& session_computation, + const ComputationHandle& handle, + const std::map& old_to_new) { + auto user_computation = + MakeUnique(session_computation.name(), handle); + { + tensorflow::mutex_lock lock(user_computation->mutex_); + 
user_computation->session_computation_ = session_computation; + user_computation->next_handle_value_ = + std::max_element(session_computation.requests().begin(), + session_computation.requests().end(), + [](const std::pair& lhs, + const std::pair& rhs) { + return lhs.first < rhs.first; + }) + ->first + + 1; + TF_RETURN_IF_ERROR(user_computation->RemapEmbeddedComputations(old_to_new)); + } + + return std::move(user_computation); +} + +UserComputation::UserComputation(const string& name, + const ComputationHandle& handle) + : name_(name), next_handle_value_(1) { + *session_computation_.mutable_computation_handle() = handle; + session_computation_.set_name(name); +} + +ComputationDataHandle UserComputation::CreateComputationDataHandle() { + ComputationDataHandle handle; + handle.set_handle(next_handle_value_); + // Handles are used as Version values and *must* be assigned consecutively for + // computation versioning to work. + next_handle_value_++; + return handle; +} + +StatusOr UserComputation::AddParameterInstruction( + const ParameterRequest& parameter_request) { + tensorflow::mutex_lock lock(mutex_); + + int64 parameter_number = parameter_request.parameter(); + if (parameters_.count(parameter_number) != 0) { + return InvalidArgument("parameter %lld already registered", + parameter_number); + } + ComputationDataHandle handle = CreateComputationDataHandle(); + + const Shape& validated_shape = parameter_request.shape(); + TF_RETURN_IF_ERROR( + ShapeUtil::ValidateShapeWithOptionalLayout(validated_shape)); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = validated_shape; + *request.mutable_request()->mutable_parameter_request() = parameter_request; + + parameters_[parameter_number] = &request; + + return handle; +} + +Status UserComputation::AddSendInstruction(const SendRequest& send_request) { + tensorflow::mutex_lock lock(mutex_); + + *session_computation_.add_send_requests() = send_request; + // Check if the operand of the instruction is valid. 
+ TF_RETURN_IF_ERROR(LookupRequest(send_request.operand()).status()); + return Status::OK(); +} + +StatusOr UserComputation::AddRecvInstruction( + const RecvRequest& recv_request) { + tensorflow::mutex_lock lock(mutex_); + + const Shape& shape = recv_request.shape(); + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape)); + ComputationDataHandle handle = CreateComputationDataHandle(); + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = shape; + *request.mutable_request()->mutable_recv_request() = recv_request; + + return handle; +} + +StatusOr UserComputation::AddPadInstruction( + const PadRequest& pad_request) { + tensorflow::mutex_lock lock(mutex_); + + TF_ASSIGN_OR_RETURN(const OperationRequest* operand, + LookupRequest(pad_request.operand())); + + TF_ASSIGN_OR_RETURN(const OperationRequest* padding_value, + LookupRequest(pad_request.padding_value())); + + TF_ASSIGN_OR_RETURN(Shape inferred_shape, ShapeInference::InferPadShape( + operand->output_shape(), + padding_value->output_shape(), + pad_request.padding_config())); + + ComputationDataHandle handle = CreateComputationDataHandle(); + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = inferred_shape; + *request.mutable_request()->mutable_pad_request() = pad_request; + + return handle; +} + +StatusOr UserComputation::AddConstantInstruction( + const ConstantRequest& constant_request) { + const Shape& validated_shape = constant_request.literal().shape(); + TF_RETURN_IF_ERROR( + ShapeUtil::ValidateShapeWithOptionalLayout(validated_shape)); + + tensorflow::mutex_lock lock(mutex_); + + ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = validated_shape; + *request.mutable_request()->mutable_constant_request() = constant_request; + + return handle; +} + +StatusOr UserComputation::AddGetTupleElementInstruction( + const GetTupleElementRequest& get_tuple_element_request) { + tensorflow::mutex_lock lock(mutex_); + + TF_ASSIGN_OR_RETURN(const OperationRequest* operand, + LookupRequest(get_tuple_element_request.operand())); + Shape element_shape = ShapeUtil::GetTupleElementShape( + operand->output_shape(), get_tuple_element_request.index()); + + ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = element_shape; + *request.mutable_request()->mutable_get_tuple_element_request() = + get_tuple_element_request; + + return handle; +} + +Status UserComputation::AddTraceInstruction(const TraceRequest& trace_request) { + tensorflow::mutex_lock lock(mutex_); + + // Verify that the operand index is valid. + TF_RETURN_IF_ERROR(LookupRequest(trace_request.operand()).status()); + + *session_computation_.add_trace_requests() = trace_request; + + return Status::OK(); +} + +StatusOr UserComputation::AddRngInstruction( + const RngRequest& rng_request) { + tensorflow::mutex_lock lock(mutex_); + + // Check the number of parameters per RNG distribution. 
+ switch (rng_request.distribution()) { + case RandomDistribution::RNG_BERNOULLI: + if (rng_request.parameter_size() != 1) { + return InvalidArgument( + "RNG distribution (%s) expects 1 parameters, but got %d", + RandomDistribution_Name(rng_request.distribution()).c_str(), + rng_request.parameter_size()); + } + break; + case RandomDistribution::RNG_NORMAL: + case RandomDistribution::RNG_UNIFORM: + if (rng_request.parameter_size() != 2) { + return InvalidArgument( + "RNG distribution (%s) expects 2 parameters, but got %d", + RandomDistribution_Name(rng_request.distribution()).c_str(), + rng_request.parameter_size()); + } + break; + default: + LOG(FATAL) << "unhandled distribution " << rng_request.distribution(); + } + + // Verify that the parameter indices are valid; + for (const ComputationDataHandle& param : rng_request.parameter()) { + TF_RETURN_IF_ERROR(LookupRequest(param).status()); + } + const Shape& validated_shape = rng_request.shape(); + TF_RETURN_IF_ERROR( + ShapeUtil::ValidateShapeWithOptionalLayout(validated_shape)); + + ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = validated_shape; + *request.mutable_request()->mutable_rng_request() = rng_request; + + return handle; +} + +StatusOr UserComputation::AddMapInstruction( + const MapRequest& map_request, + const UserComputation& to_apply_computation) { + tensorflow::mutex_lock lock(mutex_); + + std::vector operand_shapes; + for (const ComputationDataHandle& handle : map_request.operands()) { + TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookupRequest(handle)); + operand_shapes.push_back(&operand->output_shape()); + } + + VersionedComputationHandle::Version to_apply_version = + to_apply_computation.version(); + TF_ASSIGN_OR_RETURN( + std::shared_ptr to_apply_program_shape, + to_apply_computation.ComputeProgramShape(to_apply_version)); + TF_ASSIGN_OR_RETURN( + Shape inferred_shape, + ShapeInference::InferMapShape(operand_shapes, *to_apply_program_shape)); + + ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = inferred_shape; + request.add_embedded_computation_versions(to_apply_version); + *request.mutable_request()->mutable_map_request() = map_request; + + return handle; +} + +StatusOr UserComputation::AddReduceInstruction( + const ReduceRequest& reduce_request, + const UserComputation& to_apply_computation) { + tensorflow::mutex_lock lock(mutex_); + + TF_ASSIGN_OR_RETURN(const OperationRequest* operand, + LookupRequest(reduce_request.operand())); + TF_ASSIGN_OR_RETURN(const OperationRequest* init_value, + LookupRequest(reduce_request.init_value())); + + VersionedComputationHandle::Version to_apply_version = + to_apply_computation.version(); + TF_ASSIGN_OR_RETURN( + std::shared_ptr to_apply_program_shape, + to_apply_computation.ComputeProgramShape(to_apply_version)); + + TF_ASSIGN_OR_RETURN( + Shape inferred_shape, + ShapeInference::InferReduceShape( + operand->output_shape(), init_value->output_shape(), + AsInt64Slice(reduce_request.dimensions()), *to_apply_program_shape)); + + ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + 
*request.mutable_output_handle() = handle; + *request.mutable_output_shape() = inferred_shape; + request.add_embedded_computation_versions(to_apply_version); + *request.mutable_request()->mutable_reduce_request() = reduce_request; + + return handle; +} + +StatusOr UserComputation::AddReduceWindowInstruction( + const ReduceWindowRequest& reduce_window_request, + const UserComputation& to_apply_computation) { + tensorflow::mutex_lock lock(mutex_); + + TF_ASSIGN_OR_RETURN(const OperationRequest* operand, + LookupRequest(reduce_window_request.operand())); + TF_ASSIGN_OR_RETURN(const OperationRequest* init_value, + LookupRequest(reduce_window_request.init_value())); + + VersionedComputationHandle::Version to_apply_version = + to_apply_computation.version(); + TF_ASSIGN_OR_RETURN( + std::shared_ptr to_apply_program_shape, + to_apply_computation.ComputeProgramShape(to_apply_version)); + + TF_ASSIGN_OR_RETURN( + Shape inferred_shape, + ShapeInference::InferReduceWindowShape( + operand->output_shape(), init_value->output_shape(), + reduce_window_request.window(), *to_apply_program_shape)); + + ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = inferred_shape; + request.add_embedded_computation_versions(to_apply_version); + *request.mutable_request()->mutable_reduce_window_request() = + reduce_window_request; + + return handle; +} + +StatusOr UserComputation::AddSelectAndScatterInstruction( + const SelectAndScatterRequest& select_and_scatter_request, + const UserComputation& select_computation, + const UserComputation& scatter_computation) { + tensorflow::mutex_lock lock(mutex_); + + TF_ASSIGN_OR_RETURN(const OperationRequest* operand, + LookupRequest(select_and_scatter_request.operand())); + TF_ASSIGN_OR_RETURN(const OperationRequest* source, + LookupRequest(select_and_scatter_request.source())); + TF_ASSIGN_OR_RETURN(const OperationRequest* init_value, + LookupRequest(select_and_scatter_request.init_value())); + + VersionedComputationHandle::Version select_version = + select_computation.version(); + TF_ASSIGN_OR_RETURN(std::shared_ptr select_program_shape, + select_computation.ComputeProgramShape(select_version)); + VersionedComputationHandle::Version scatter_version = + scatter_computation.version(); + TF_ASSIGN_OR_RETURN(std::shared_ptr scatter_program_shape, + scatter_computation.ComputeProgramShape(scatter_version)); + + TF_ASSIGN_OR_RETURN( + Shape inferred_shape, + ShapeInference::InferSelectAndScatterShape( + operand->output_shape(), *select_program_shape, + select_and_scatter_request.window(), source->output_shape(), + init_value->output_shape(), *scatter_program_shape)); + + ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = inferred_shape; + request.add_embedded_computation_versions(select_version); + request.add_embedded_computation_versions(scatter_version); + *request.mutable_request()->mutable_select_and_scatter_request() = + select_and_scatter_request; + + return handle; +} + +StatusOr UserComputation::AddReverseInstruction( + const ReverseRequest& reverse_request) { + tensorflow::mutex_lock lock(mutex_); + + TF_ASSIGN_OR_RETURN(const OperationRequest* operand, + LookupRequest(reverse_request.operand())); + 
TF_ASSIGN_OR_RETURN( + Shape inferred_shape, + ShapeInference::InferReverseShape( + operand->output_shape(), AsInt64Slice(reverse_request.dimensions()))); + + ComputationDataHandle handle = CreateComputationDataHandle(); + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = inferred_shape; + *request.mutable_request()->mutable_reverse_request() = reverse_request; + return handle; +} + +StatusOr UserComputation::AddWhileInstruction( + const WhileRequest& while_request, + const UserComputation& condition_computation, + const UserComputation& body_computation) { + tensorflow::mutex_lock lock(mutex_); + + TF_ASSIGN_OR_RETURN(const OperationRequest* init, + LookupRequest(while_request.init())); + + VersionedComputationHandle::Version condition_version = + condition_computation.version(); + TF_ASSIGN_OR_RETURN( + std::shared_ptr condition_program_shape, + condition_computation.ComputeProgramShape(condition_version)); + + VersionedComputationHandle::Version body_version = body_computation.version(); + TF_ASSIGN_OR_RETURN(std::shared_ptr body_program_shape, + body_computation.ComputeProgramShape(body_version)); + + TF_ASSIGN_OR_RETURN( + Shape inferred_shape, + ShapeInference::InferWhileShape( + *condition_program_shape, *body_program_shape, init->output_shape())); + + ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = inferred_shape; + request.add_embedded_computation_versions(condition_version); + request.add_embedded_computation_versions(body_version); + *request.mutable_request()->mutable_while_request() = while_request; + + return handle; +} + +StatusOr UserComputation::AddBroadcastInstruction( + const BroadcastRequest& broadcast_request) { + tensorflow::mutex_lock lock(mutex_); + + // Fetches and validates the operand. + TF_ASSIGN_OR_RETURN(const OperationRequest* operand, + LookupRequest(broadcast_request.operand())); + TF_ASSIGN_OR_RETURN(Shape inferred_shape, + ShapeInference::InferBroadcastShape( + operand->output_shape(), + AsInt64Slice(broadcast_request.broadcast_sizes()))); + + ComputationDataHandle handle = CreateComputationDataHandle(); + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = inferred_shape; + *request.mutable_request()->mutable_broadcast_request() = broadcast_request; + return handle; +} + +StatusOr UserComputation::AddReshapeInstruction( + const ReshapeRequest& reshape_request) { + tensorflow::mutex_lock lock(mutex_); + + // Fetches and validates the operand. 
+ TF_ASSIGN_OR_RETURN(const OperationRequest* operand, + LookupRequest(reshape_request.operand())); + + TF_ASSIGN_OR_RETURN( + Shape inferred_shape, + ShapeInference::InferReshapeShape( + operand->output_shape(), AsInt64Slice(reshape_request.dimensions()), + AsInt64Slice(reshape_request.new_sizes()))); + + ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = inferred_shape; + *request.mutable_request()->mutable_reshape_request() = reshape_request; + + return handle; +} + +StatusOr UserComputation::AddSliceInstruction( + const SliceRequest& slice_request) { + tensorflow::mutex_lock lock(mutex_); + + TF_ASSIGN_OR_RETURN(const OperationRequest* operand, + LookupRequest(slice_request.operand())); + + TF_ASSIGN_OR_RETURN( + Shape new_shape, + ShapeInference::InferSliceShape( + operand->output_shape(), AsInt64Slice(slice_request.start_indices()), + AsInt64Slice(slice_request.limit_indices()))); + + ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = new_shape; + *request.mutable_request()->mutable_slice_request() = slice_request; + + return handle; +} + +StatusOr UserComputation::AddDynamicSliceInstruction( + const DynamicSliceRequest& dynamic_slice_request) { + tensorflow::mutex_lock lock(mutex_); + + TF_ASSIGN_OR_RETURN(const OperationRequest* operand, + LookupRequest(dynamic_slice_request.operand())); + + TF_ASSIGN_OR_RETURN(const OperationRequest* start_indices, + LookupRequest(dynamic_slice_request.start_indices())); + + TF_ASSIGN_OR_RETURN( + Shape new_shape, + ShapeInference::InferDynamicSliceShape( + operand->output_shape(), start_indices->output_shape(), + AsInt64Slice(dynamic_slice_request.slice_sizes()))); + + ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = new_shape; + *request.mutable_request()->mutable_dynamic_slice_request() = + dynamic_slice_request; + + return handle; +} + +StatusOr +UserComputation::AddDynamicUpdateSliceInstruction( + const DynamicUpdateSliceRequest& dynamic_update_slice_request) { + tensorflow::mutex_lock lock(mutex_); + + TF_ASSIGN_OR_RETURN(const OperationRequest* operand, + LookupRequest(dynamic_update_slice_request.operand())); + + TF_ASSIGN_OR_RETURN(const OperationRequest* update, + LookupRequest(dynamic_update_slice_request.update())); + + TF_ASSIGN_OR_RETURN( + const OperationRequest* start_indices, + LookupRequest(dynamic_update_slice_request.start_indices())); + + TF_ASSIGN_OR_RETURN(Shape new_shape, + ShapeInference::InferDynamicUpdateSliceShape( + operand->output_shape(), update->output_shape(), + start_indices->output_shape())); + + ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = new_shape; + *request.mutable_request()->mutable_dynamic_update_slice_request() = + dynamic_update_slice_request; + + return handle; +} + +StatusOr UserComputation::AddConcatenateInstruction( + const ConcatenateRequest& 
concatenate_request) { + tensorflow::mutex_lock lock(mutex_); + + std::vector operand_shapes; + for (const ComputationDataHandle& handle : concatenate_request.operands()) { + TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookupRequest(handle)); + operand_shapes.push_back(&operand->output_shape()); + } + + TF_ASSIGN_OR_RETURN(Shape new_shape, + ShapeInference::InferConcatOpShape( + operand_shapes, concatenate_request.dimension())); + + ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = new_shape; + *request.mutable_request()->mutable_concatenate_request() = + concatenate_request; + + return handle; +} + +StatusOr UserComputation::AddConvertInstruction( + const ConvertRequest& convert_request) { + tensorflow::mutex_lock lock(mutex_); + + TF_ASSIGN_OR_RETURN(const OperationRequest* operand, + LookupRequest(convert_request.operand())); + + TF_ASSIGN_OR_RETURN(Shape new_shape, ShapeInference::InferConvertShape( + operand->output_shape(), + convert_request.new_element_type())); + + ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = new_shape; + *request.mutable_request()->mutable_convert_request() = convert_request; + + return handle; +} + +StatusOr UserComputation::AddConvolveInstruction( + const ConvolveRequest& convolve_request) { + tensorflow::mutex_lock lock(mutex_); + + TF_ASSIGN_OR_RETURN(const OperationRequest* lhs, + LookupRequest(convolve_request.lhs())); + TF_ASSIGN_OR_RETURN(const OperationRequest* rhs, + LookupRequest(convolve_request.rhs())); + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvolveShape( + lhs->output_shape(), rhs->output_shape(), + convolve_request.window(), + convolve_request.dimension_numbers())); + + const ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = shape; + *request.mutable_request()->mutable_convolve_request() = convolve_request; + + return handle; +} + +StatusOr UserComputation::AddCrossReplicaSumInstruction( + const CrossReplicaSumRequest& cross_replica_sum_request) { + tensorflow::mutex_lock lock(mutex_); + + TF_ASSIGN_OR_RETURN(const OperationRequest* operand, + LookupRequest(cross_replica_sum_request.operand())); + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferCrossReplicaSumShape( + operand->output_shape())); + + ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = shape; + *request.mutable_request()->mutable_cross_replica_sum_request() = + cross_replica_sum_request; + + return handle; +} + +StatusOr UserComputation::AddInfeedInstruction( + const InfeedRequest& infeed_request) { + tensorflow::mutex_lock lock(mutex_); + + const Shape& shape = infeed_request.shape(); + if (ShapeUtil::IsNestedTuple(shape)) { + return InvalidArgument("Infeed does not support nested tuple shapes"); + } + if (!LayoutUtil::HasLayout(shape)) { + return InvalidArgument("Given shape to Infeed must have a layout"); + } 
+ + const ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = shape; + *request.mutable_request()->mutable_infeed_request() = infeed_request; + + return handle; +} + +StatusOr UserComputation::AddCallInstruction( + const CallRequest& call_request, + const UserComputation& to_apply_computation) { + tensorflow::mutex_lock lock(mutex_); + + std::vector operand_shapes; + for (const ComputationDataHandle& handle : call_request.operands()) { + TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookupRequest(handle)); + operand_shapes.push_back(&operand->output_shape()); + } + + VersionedComputationHandle::Version to_apply_version = + to_apply_computation.version(); + TF_ASSIGN_OR_RETURN( + std::shared_ptr to_apply_program_shape, + to_apply_computation.ComputeProgramShape(to_apply_version)); + TF_ASSIGN_OR_RETURN( + Shape inferred_shape, + ShapeInference::InferCallShape(operand_shapes, *to_apply_program_shape)); + + ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = inferred_shape; + request.add_embedded_computation_versions(to_apply_version); + *request.mutable_request()->mutable_call_request() = call_request; + + return handle; +} + +StatusOr UserComputation::AddCustomCallInstruction( + const CustomCallRequest& custom_call_request) { + tensorflow::mutex_lock lock(mutex_); + + for (const ComputationDataHandle& handle : custom_call_request.operands()) { + TF_RETURN_IF_ERROR(LookupRequest(handle).status()); + } + + const ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = custom_call_request.shape(); + *request.mutable_request()->mutable_custom_call_request() = + custom_call_request; + + return handle; +} + +StatusOr UserComputation::AddUnaryInstruction( + const UnaryOpRequest& unary_request) { + tensorflow::mutex_lock lock(mutex_); + + TF_ASSIGN_OR_RETURN(const OperationRequest* operand, + LookupRequest(unary_request.operand())); + TF_ASSIGN_OR_RETURN( + Shape shape, ShapeInference::InferUnaryOpShape(unary_request.unop(), + operand->output_shape())); + + ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = shape; + *request.mutable_request()->mutable_unary_op_request() = unary_request; + + return handle; +} + +StatusOr UserComputation::AddBinaryInstruction( + const BinaryOpRequest& binary_request) { + tensorflow::mutex_lock lock(mutex_); + + TF_ASSIGN_OR_RETURN(const OperationRequest* lhs, + LookupRequest(binary_request.lhs())); + TF_ASSIGN_OR_RETURN(const OperationRequest* rhs, + LookupRequest(binary_request.rhs())); + TF_ASSIGN_OR_RETURN( + Shape shape, + ShapeInference::InferBinaryOpShape( + binary_request.binop(), lhs->output_shape(), rhs->output_shape(), + AsInt64Slice(binary_request.broadcast_dimensions()))); + + ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + 
(*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = shape; + *request.mutable_request()->mutable_binary_op_request() = binary_request; + + return handle; +} + +StatusOr UserComputation::AddTernaryInstruction( + const TernaryOpRequest& ternary_request) { + tensorflow::mutex_lock lock(mutex_); + + TF_ASSIGN_OR_RETURN(const OperationRequest* lhs, + LookupRequest(ternary_request.lhs())); + TF_ASSIGN_OR_RETURN(const OperationRequest* rhs, + LookupRequest(ternary_request.rhs())); + TF_ASSIGN_OR_RETURN(const OperationRequest* ehs, + LookupRequest(ternary_request.ehs())); + TF_ASSIGN_OR_RETURN(Shape shape, + ShapeInference::InferTernaryOpShape( + ternary_request.triop(), lhs->output_shape(), + rhs->output_shape(), ehs->output_shape())); + + ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = shape; + *request.mutable_request()->mutable_ternary_op_request() = ternary_request; + + return handle; +} + +StatusOr UserComputation::AddVariadicInstruction( + const VariadicOpRequest& variadic_request) { + tensorflow::mutex_lock lock(mutex_); + + std::vector operand_shapes; + for (const ComputationDataHandle& handle : variadic_request.operands()) { + TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookupRequest(handle)); + operand_shapes.push_back(&operand->output_shape()); + } + + TF_ASSIGN_OR_RETURN(Shape shape, + ShapeInference::InferVariadicOpShape( + variadic_request.varop(), operand_shapes)); + + ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = shape; + *request.mutable_request()->mutable_variadic_op_request() = variadic_request; + + return handle; +} + +StatusOr UserComputation::GetShape(const ComputationDataHandle& handle) { + tensorflow::mutex_lock lock(mutex_); + + TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookupRequest(handle)); + return operand->output_shape(); +} + +Status UserComputation::SetReturnValue(const ComputationDataHandle& handle) { + tensorflow::mutex_lock lock(mutex_); + + if (!(handle.handle() > 0 && handle.handle() < next_handle_value_)) { + return InvalidArgument("Invalid handle in SetReturnValue"); + } + + handle_to_return_ = handle; + + return Status::OK(); +} + +VersionedComputationHandle UserComputation::GetVersionedHandle() const { + tensorflow::mutex_lock lock(mutex_); + + VersionedComputationHandle versioned_handle; + versioned_handle.handle = session_computation_.computation_handle(); + + if (handle_to_return_.handle() > 0) { + // A specific handle has been requested for the result of the computation. + versioned_handle.version = handle_to_return_.handle(); + } else { + // A version value is simply the most recently assigned + // ComputationDataHandle value, ie the handle value of the root of the + // computation. + versioned_handle.version = next_handle_value_ - 1; + } + + return versioned_handle; +} + +VersionedComputationHandle UserComputation::GetVersionedHandleAtOperation( + const ComputationDataHandle& operation) const { + tensorflow::mutex_lock lock(mutex_); + + // The version at which an operation was added is simply the handle value of + // the ComputationDataHandle. 
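// ---------------------------------------------------------------------------
// Aside (not part of this patch): the comment above summarizes the versioning
// scheme. Because handles are assigned sequentially, "the version" of a
// computation is just the handle of the most recently added operation (or of
// an explicitly pinned return value), and the version at which any operation
// was added is that operation's own handle. A small standalone model:
#include <cstdint>
#include <iostream>

namespace sketch {

struct VersionModel {
  int64_t next_handle_value = 1;  // next handle to hand out
  int64_t handle_to_return = 0;   // 0 means "no explicit return value"

  int64_t AddOperation() { return next_handle_value++; }
  void SetReturnValue(int64_t handle) { handle_to_return = handle; }
  int64_t version() const {
    return handle_to_return > 0 ? handle_to_return : next_handle_value - 1;
  }
};

}  // namespace sketch

int main() {
  sketch::VersionModel computation;
  int64_t a = computation.AddOperation();      // handle 1
  int64_t b = computation.AddOperation();      // handle 2
  std::cout << computation.version() << "\n";  // 2: version follows the root
  computation.SetReturnValue(a);
  std::cout << computation.version() << "\n";  // 1: pinned return value wins
  (void)b;
  return 0;
}
// ---------------------------------------------------------------------------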
+  VersionedComputationHandle versioned_handle;
+  versioned_handle.handle = session_computation_.computation_handle();
+  versioned_handle.version = operation.handle();
+  return versioned_handle;
+}
+
+VersionedComputationHandle::Version UserComputation::version() const {
+  return GetVersionedHandle().version;
+}
+
+StatusOr<std::shared_ptr<const ProgramShape>>
+UserComputation::ComputeProgramShape(
+    VersionedComputationHandle::Version version) const {
+  tensorflow::mutex_lock lock(mutex_);
+
+  CHECK(version > 0 && version < next_handle_value_);
+
+  if (program_shape_ == nullptr || program_shape_version_ != version) {
+    // ProgramShape has not been computed yet, or is for a different
+    // version. Compute it now.
+    TF_RETURN_IF_ERROR(CheckParametersAreContiguous(version));
+
+    auto program_shape = MakeUnique<ProgramShape>();
+    for (int64 request_num = 1; request_num <= version; ++request_num) {
+      const OperationRequest& request =
+          session_computation_.requests().at(request_num);
+      if (request.request().op_case() == OpRequest::kParameterRequest) {
+        const ParameterRequest& parameter_request =
+            request.request().parameter_request();
+        int64 param_no = parameter_request.parameter();
+        // Parameters may be out of order so expand ProgramShape parameters
+        // until it is at least large enough to hold the current parameter
+        // number.
+        while (program_shape->parameters_size() <= param_no) {
+          program_shape->add_parameters();
+          program_shape->add_parameter_names();
+        }
+        *program_shape->mutable_parameters(param_no) = request.output_shape();
+        *program_shape->mutable_parameter_names(param_no) =
+            parameter_request.name();
+      }
+    }
+
+    // The root determines the output shape.
+    *program_shape->mutable_result() = GetRoot(version).output_shape();
+    if (ShapeUtil::IsOpaque(program_shape->result())) {
+      return Unimplemented("Computation results cannot be opaque");
+    }
+
+    program_shape_ = std::move(program_shape);
+    program_shape_version_ = version;
+  }
+
+  return program_shape_;
+}
+
+namespace {
+
+// A visitor which checks whether an operation is a compile-time constant. That
+// is, the operation does not depend on any parameter instructions. The visitor
+// walks the computation starting at a given operation and sets is_constant to
+// false iff a parameter or RNG operation is encountered.
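// ---------------------------------------------------------------------------
// Aside (not part of this patch): a standalone model of the constant analysis
// implemented by the visitor described above and defined below. Operations
// form a DAG; an operation is a compile-time constant iff no Parameter (or
// RNG, Infeed, etc.) operation is reachable from it. The walk memoizes visited
// handles so shared operands are examined only once. Op and Kind here are
// hypothetical stand-ins for the OpRequest protos.
#include <cstdint>
#include <map>
#include <set>
#include <vector>

namespace sketch {

enum class Kind { kConstant, kParameter, kRng, kAdd };

struct Op {
  Kind kind;
  std::vector<int64_t> operands;  // handles of operand ops
};

void ConstantVisitorModel(const std::map<int64_t, Op>& ops, int64_t handle,
                          std::set<int64_t>* visited, bool* is_constant) {
  if (visited->count(handle) != 0 || !*is_constant) {
    return;  // already examined, or the answer is already "not constant"
  }
  const Op& op = ops.at(handle);
  if (op.kind == Kind::kParameter || op.kind == Kind::kRng) {
    *is_constant = false;  // depends on runtime input or randomness
  } else {
    for (int64_t operand : op.operands) {
      ConstantVisitorModel(ops, operand, visited, is_constant);
    }
  }
  visited->insert(handle);
}

bool IsConstantModel(const std::map<int64_t, Op>& ops, int64_t handle) {
  bool is_constant = true;
  std::set<int64_t> visited;
  ConstantVisitorModel(ops, handle, &visited, &is_constant);
  return is_constant;
}

}  // namespace sketch
// ---------------------------------------------------------------------------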
+void ConstantVisitor(const SessionComputation& session_computation, + const ComputationDataHandle& handle, + std::set* visited, bool* is_constant) { + if (visited->count(handle.handle()) != 0 || !*is_constant) { + return; + } + + const OperationRequest& request = + session_computation.requests().at(handle.handle()); + switch (request.request().op_case()) { + case OpRequest::kRngRequest: + *is_constant = false; + break; + + case OpRequest::kConstantRequest: + break; + + case OpRequest::kGetTupleElementRequest: { + const GetTupleElementRequest& get_tuple_element_request = + request.request().get_tuple_element_request(); + ConstantVisitor(session_computation, get_tuple_element_request.operand(), + visited, is_constant); + break; + } + + case OpRequest::kSliceRequest: { + const SliceRequest& slice_request = request.request().slice_request(); + ConstantVisitor(session_computation, slice_request.operand(), visited, + is_constant); + break; + } + + case OpRequest::kDynamicSliceRequest: { + const DynamicSliceRequest& dynamic_slice_request = + request.request().dynamic_slice_request(); + ConstantVisitor(session_computation, dynamic_slice_request.operand(), + visited, is_constant); + ConstantVisitor(session_computation, + dynamic_slice_request.start_indices(), visited, + is_constant); + break; + } + + case OpRequest::kDynamicUpdateSliceRequest: { + const DynamicUpdateSliceRequest& dynamic_update_slice_request = + request.request().dynamic_update_slice_request(); + ConstantVisitor(session_computation, + dynamic_update_slice_request.operand(), visited, + is_constant); + ConstantVisitor(session_computation, + dynamic_update_slice_request.update(), visited, + is_constant); + ConstantVisitor(session_computation, + dynamic_update_slice_request.start_indices(), visited, + is_constant); + break; + } + + case OpRequest::kConcatenateRequest: { + const ConcatenateRequest& concatenate_request = + request.request().concatenate_request(); + for (const ComputationDataHandle& handle : + concatenate_request.operands()) { + ConstantVisitor(session_computation, handle, visited, is_constant); + } + break; + } + + case OpRequest::kConvolveRequest: { + const ConvolveRequest& convolve_request = + request.request().convolve_request(); + ConstantVisitor(session_computation, convolve_request.lhs(), visited, + is_constant); + ConstantVisitor(session_computation, convolve_request.rhs(), visited, + is_constant); + break; + } + + case OpRequest::kCrossReplicaSumRequest: { + // TODO(b/33009255): Implmement constant folding for cross replica sum. + *is_constant = false; + break; + } + + case OpRequest::kInfeedRequest: { + *is_constant = false; + break; + } + + case OpRequest::kCallRequest: { + const CallRequest& call_request = request.request().call_request(); + for (const ComputationDataHandle& handle : call_request.operands()) { + ConstantVisitor(session_computation, handle, visited, is_constant); + } + // TODO(b/32495713): We aren't checking the to_apply computation itself, + // so we conservatively say that computations containing the Call op + // cannot be constant. We cannot set is_constant=false in other similar + // cases since we're already relying on IsConstant to return true. 
+ *is_constant = false; + break; + } + + case OpRequest::kCustomCallRequest: { + *is_constant = false; + break; + } + + case OpRequest::kMapRequest: { + const MapRequest& map_request = request.request().map_request(); + for (const ComputationDataHandle& handle : map_request.operands()) { + ConstantVisitor(session_computation, handle, visited, is_constant); + } + // TODO(b/32495713): We aren't checking the to_apply computation itself. + break; + } + + case OpRequest::kReduceRequest: { + const ReduceRequest& reduce_request = request.request().reduce_request(); + ConstantVisitor(session_computation, reduce_request.operand(), visited, + is_constant); + ConstantVisitor(session_computation, reduce_request.init_value(), visited, + is_constant); + // TODO(b/32495713): We aren't checking the to_apply computation itself. + break; + } + + case OpRequest::kReduceWindowRequest: { + const ReduceWindowRequest& reduce_window_request = + request.request().reduce_window_request(); + ConstantVisitor(session_computation, reduce_window_request.operand(), + visited, is_constant); + ConstantVisitor(session_computation, reduce_window_request.init_value(), + visited, is_constant); + // TODO(b/32495713): We aren't checking the to_apply computation itself. + break; + } + + case OpRequest::kSelectAndScatterRequest: { + const SelectAndScatterRequest& select_and_scatter_request = + request.request().select_and_scatter_request(); + ConstantVisitor(session_computation, select_and_scatter_request.operand(), + visited, is_constant); + ConstantVisitor(session_computation, select_and_scatter_request.source(), + visited, is_constant); + ConstantVisitor(session_computation, + select_and_scatter_request.init_value(), visited, + is_constant); + // TODO(b/32495713): We aren't checking the select and scatter + // computations themselves. + break; + } + + case OpRequest::kBroadcastRequest: { + const BroadcastRequest& broadcast_request = + request.request().broadcast_request(); + ConstantVisitor(session_computation, broadcast_request.operand(), visited, + is_constant); + break; + } + + case OpRequest::kReshapeRequest: { + const ReshapeRequest& reshape_request = + request.request().reshape_request(); + ConstantVisitor(session_computation, reshape_request.operand(), visited, + is_constant); + break; + } + + case OpRequest::kReverseRequest: { + const ReverseRequest& reverse_request = + request.request().reverse_request(); + ConstantVisitor(session_computation, reverse_request.operand(), visited, + is_constant); + break; + } + + case OpRequest::kPadRequest: { + const PadRequest& pad_request = request.request().pad_request(); + ConstantVisitor(session_computation, pad_request.operand(), visited, + is_constant); + ConstantVisitor(session_computation, pad_request.padding_value(), visited, + is_constant); + break; + } + + case OpRequest::kParameterRequest: { + *is_constant = false; + break; + } + + case OpRequest::kConvertRequest: { + const ConvertRequest& convert_request = + request.request().convert_request(); + ConstantVisitor(session_computation, convert_request.operand(), visited, + is_constant); + break; + } + + case OpRequest::kWhileRequest: { + const WhileRequest& while_request = request.request().while_request(); + ConstantVisitor(session_computation, while_request.init(), visited, + is_constant); + // TODO(b/32495713): We aren't checking the condition and body + // computations themselves. 
+ break; + } + + case OpRequest::kTernaryOpRequest: { + const TernaryOpRequest& ternary_op_request = + request.request().ternary_op_request(); + ConstantVisitor(session_computation, ternary_op_request.lhs(), visited, + is_constant); + ConstantVisitor(session_computation, ternary_op_request.rhs(), visited, + is_constant); + ConstantVisitor(session_computation, ternary_op_request.ehs(), visited, + is_constant); + break; + } + + case OpRequest::kVariadicOpRequest: { + const VariadicOpRequest& variadic_op_request = + request.request().variadic_op_request(); + for (const ComputationDataHandle& handle : + variadic_op_request.operands()) { + ConstantVisitor(session_computation, handle, visited, is_constant); + } + break; + } + + case OpRequest::kUnaryOpRequest: { + const UnaryOpRequest& unary_op_request = + request.request().unary_op_request(); + ConstantVisitor(session_computation, unary_op_request.operand(), visited, + is_constant); + break; + } + + case OpRequest::kBinaryOpRequest: { + const BinaryOpRequest& binary_op_request = + request.request().binary_op_request(); + ConstantVisitor(session_computation, binary_op_request.lhs(), visited, + is_constant); + ConstantVisitor(session_computation, binary_op_request.rhs(), visited, + is_constant); + break; + } + + case OpRequest::OP_NOT_SET: + LOG(FATAL) << "OperationRequest doesn't contain a request"; + + default: + LOG(FATAL) << "Unexpected request type: " << request.request().op_case(); + } + visited->insert(handle.handle()); +} + +} // namespace + +StatusOr UserComputation::IsConstant( + const ComputationDataHandle& handle) { + tensorflow::mutex_lock lock(mutex_); + + // Verify that the handle is valid. + auto operation_status = LookupRequest(handle); + if (!operation_status.ok()) { + return operation_status.status(); + } + + bool is_constant = true; + std::set visited; + ConstantVisitor(session_computation_, handle, &visited, &is_constant); + + return is_constant; +} + +const OperationRequest& UserComputation::GetRoot( + VersionedComputationHandle::Version version) const { + CHECK(version > 0 && version < next_handle_value_); + return session_computation_.requests().at(version); +} + +std::vector +UserComputation::GetEmbeddedComputations( + VersionedComputationHandle::Version version) const { + tensorflow::mutex_lock lock(mutex_); + + std::vector computations; + for (const auto& handle_request : session_computation_.requests()) { + int64 handle_value = handle_request.first; + if (handle_value <= version) { + const OperationRequest& request = handle_request.second; + switch (request.request().op_case()) { + case OpRequest::kCallRequest: { + CHECK_EQ(1, request.embedded_computation_versions_size()); + const CallRequest& call_request = request.request().call_request(); + const VersionedComputationHandle versioned_handle = { + call_request.to_apply(), + request.embedded_computation_versions(0)}; + computations.push_back(versioned_handle); + break; + } + + case OpRequest::kMapRequest: { + CHECK_EQ(1, request.embedded_computation_versions_size()); + const MapRequest& map_request = request.request().map_request(); + const VersionedComputationHandle versioned_handle = { + map_request.to_apply(), request.embedded_computation_versions(0)}; + computations.push_back(versioned_handle); + break; + } + + case OpRequest::kReduceRequest: { + CHECK_EQ(1, request.embedded_computation_versions_size()); + const ReduceRequest& reduce_request = + request.request().reduce_request(); + const VersionedComputationHandle versioned_handle = { + reduce_request.to_apply(), 
+ request.embedded_computation_versions(0)}; + computations.push_back(versioned_handle); + break; + } + + case OpRequest::kReduceWindowRequest: { + CHECK_EQ(1, request.embedded_computation_versions_size()); + const ReduceWindowRequest& reduce_window_request = + request.request().reduce_window_request(); + const VersionedComputationHandle versioned_handle = { + reduce_window_request.to_apply(), + request.embedded_computation_versions(0)}; + computations.push_back(versioned_handle); + break; + } + + case OpRequest::kSelectAndScatterRequest: { + CHECK_EQ(2, request.embedded_computation_versions_size()); + const SelectAndScatterRequest& select_and_scatter_request = + request.request().select_and_scatter_request(); + const VersionedComputationHandle select_versioned_handle = { + select_and_scatter_request.select(), + request.embedded_computation_versions(0)}; + computations.push_back(select_versioned_handle); + const VersionedComputationHandle scatter_versioned_handle = { + select_and_scatter_request.scatter(), + request.embedded_computation_versions(1)}; + computations.push_back(scatter_versioned_handle); + break; + } + + case OpRequest::kWhileRequest: { + CHECK_EQ(2, request.embedded_computation_versions_size()); + const WhileRequest& while_request = request.request().while_request(); + const VersionedComputationHandle condition_versioned_handle = { + while_request.condition(), + request.embedded_computation_versions(0)}; + computations.push_back(condition_versioned_handle); + const VersionedComputationHandle body_versioned_handle = { + while_request.body(), request.embedded_computation_versions(1)}; + computations.push_back(body_versioned_handle); + break; + } + + default: + // No embedded computation. + break; + } + } + } + return computations; +} + +Status UserComputation::RemapEmbeddedComputations( + const std::map& old_to_new) { + auto update = [&old_to_new](ComputationHandle* to_update) -> Status { + int64 old = to_update->handle(); + auto it = old_to_new.find(old); + if (it == old_to_new.end()) { + string mapping = tensorflow::str_util::Join( + old_to_new, ", ", + [](string* out, std::pair element) { + tensorflow::strings::Appendf(out, "%lld:%lld", element.first, + element.second.handle()); + }); + return NotFound( + "could not find referenced (old) computation handle in mapping: " + "%lld; mapping: {%s}", + old, mapping.c_str()); + } + VLOG(2) << "remapping " << old << " to " << it->second.handle(); + *to_update = it->second; + return Status::OK(); + }; + TF_RETURN_IF_ERROR(update(session_computation_.mutable_computation_handle())); + for (auto& handle_request : *session_computation_.mutable_requests()) { + OperationRequest& request = handle_request.second; + switch (request.request().op_case()) { + case OpRequest::kCallRequest: { + TF_RET_CHECK(1 == request.embedded_computation_versions_size()); + CallRequest* call_request = + request.mutable_request()->mutable_call_request(); + TF_RETURN_IF_ERROR(update(call_request->mutable_to_apply())); + break; + } + case OpRequest::kMapRequest: { + TF_RET_CHECK(1 == request.embedded_computation_versions_size()); + MapRequest* map_request = + request.mutable_request()->mutable_map_request(); + TF_RETURN_IF_ERROR(update(map_request->mutable_to_apply())); + break; + } + case OpRequest::kReduceRequest: { + TF_RET_CHECK(1 == request.embedded_computation_versions_size()); + ReduceRequest* reduce_request = + request.mutable_request()->mutable_reduce_request(); + TF_RETURN_IF_ERROR(update(reduce_request->mutable_to_apply())); + break; + } + case 
OpRequest::kReduceWindowRequest: { + TF_RET_CHECK(1 == request.embedded_computation_versions_size()); + ReduceWindowRequest* reduce_window_request = + request.mutable_request()->mutable_reduce_window_request(); + TF_RETURN_IF_ERROR(update(reduce_window_request->mutable_to_apply())); + break; + } + case OpRequest::kSelectAndScatterRequest: { + TF_RET_CHECK(2 == request.embedded_computation_versions_size()); + SelectAndScatterRequest* select_and_scatter_request = + request.mutable_request()->mutable_select_and_scatter_request(); + TF_RETURN_IF_ERROR( + update(select_and_scatter_request->mutable_select())); + TF_RETURN_IF_ERROR( + update(select_and_scatter_request->mutable_scatter())); + break; + } + case OpRequest::kWhileRequest: { + TF_RET_CHECK(2 == request.embedded_computation_versions_size()); + WhileRequest* while_request = + request.mutable_request()->mutable_while_request(); + TF_RETURN_IF_ERROR(update(while_request->mutable_condition())); + TF_RETURN_IF_ERROR(update(while_request->mutable_body())); + break; + } + default: + // No embedded computation. + TF_RET_CHECK(0 == request.embedded_computation_versions_size()); + break; + } + } + return Status::OK(); +} + +SessionComputation UserComputation::CloneSessionComputation( + VersionedComputationHandle::Version version) const { + tensorflow::mutex_lock lock(mutex_); + SessionComputation result = session_computation_; + // Erase all the requests that exceed the version specified. + // There's no lower_bound method on tensorflow::protobuf::Map so we iterate + // all the elements. + auto it = result.mutable_requests()->begin(); + while (it != result.mutable_requests()->end()) { + if (it->first > version) { + it = result.mutable_requests()->erase(it); + } else { + ++it; + } + } + return result; +} + +StatusOr UserComputation::LookupRequest( + const ComputationDataHandle& handle) const { + int64 handle_value = handle.handle(); + if (session_computation_.requests().count(handle_value) == 0) { + return InvalidArgument("no ComputationDataHandle value %lld", handle_value); + } + return &session_computation_.requests().at(handle_value); +} + +Status UserComputation::CheckParametersAreContiguous( + VersionedComputationHandle::Version version) const { + TF_RET_CHECK(version > 0 && version < next_handle_value_); + + // Determine number of parameter inputs at the given version. + std::map parameter_requests; + for (int64 request_num = 1; request_num <= version; ++request_num) { + const OperationRequest& request = + session_computation_.requests().at(request_num); + + if (request.request().op_case() == OpRequest::kParameterRequest) { + const ParameterRequest& parameter_request = + request.request().parameter_request(); + // Duplicate parameters should be checked when parameter requests are + // added. + TF_RET_CHECK(0 == + parameter_requests.count(parameter_request.parameter())); + parameter_requests[parameter_request.parameter()] = ¶meter_request; + } + } + + auto program_shape = MakeUnique(); + for (int64 i = 0; i < parameter_requests.size(); ++i) { + auto it = parameter_requests.find(i); + if (it == parameter_requests.end()) { + return FailedPrecondition( + "computation %s does not have all its parameters populated " + "sequentially, missing parameter %lld", + name_.c_str(), i); + } + } + + return Status::OK(); +} + +namespace { + +// Helper class which builds an HLO computation from a SessionComputation. 
To +// construct the HLO computation, the SessionComputation graph is walked in +// DFS order lowering each OperationRequest to an HLO instruction. +class ComputationLowerer { + public: + static std::unique_ptr Lower( + const string& computation_name, + const SessionComputation& session_computation, + VersionedComputationHandle::Version version, + UserComputation::HloComputationResolver hlo_resolver, + bool include_unused_parameters) { + ComputationLowerer lowerer(computation_name, session_computation, version, + std::move(hlo_resolver)); + return lowerer.Lower(include_unused_parameters); + } + + private: + ComputationLowerer(const string& computation_name, + const SessionComputation& session_computation, + VersionedComputationHandle::Version version, + UserComputation::HloComputationResolver hlo_resolver) + : hlo_builder_(computation_name), + session_computation_(session_computation), + version_(version), + hlo_resolver_(std::move(hlo_resolver)) {} + + // Build an HLO computation from the SessionComputation at the given + // version. + std::unique_ptr Lower(bool include_unused_parameters); + + private: + // DFS visitor of the UserComputation operations which lowers the operations + // to HLO instructions. + HloInstruction* Visit(const ComputationDataHandle& handle, + std::map* visited); + + // Resolves a ComputationHandle and Version to a previously lowered + // HloComputation using the hlo_resolver_ function. + HloComputation* ResolveComputation( + const ComputationHandle& handle, + VersionedComputationHandle::Version version); + + HloComputation::Builder hlo_builder_; + const SessionComputation& session_computation_; + const VersionedComputationHandle::Version version_; + const UserComputation::HloComputationResolver hlo_resolver_; +}; + +std::unique_ptr ComputationLowerer::Lower( + bool include_unused_parameters) { + // Map from ComputationDataHandle to HLO instruction. Serves as a record of + // which operations have been visited as well as a cache for looking up + // ComputationDataHandles as HloInstructions. + std::map visited; + + // A version is simply a ComputationDataHandle of the root of the computation + // at the time the version was generated. Create a ComputationDataHandle with + // this value and pass it to the visitor as the root of the computation to + // lower. + ComputationDataHandle root_handle; + root_handle.set_handle(version_); + + HloInstruction* hlo_root = Visit(root_handle, &visited); + + // A computation may have unused parameters. + if (include_unused_parameters) { + for (int64 request_num = 1; request_num <= version_; ++request_num) { + const OperationRequest& request = + session_computation_.requests().at(request_num); + if (request.request().op_case() == OpRequest::kParameterRequest && + visited.count(request.output_handle().handle()) == 0) { + Visit(request.output_handle(), &visited); + } + } + } + + // Add trace instructions. + for (const auto& trace_request : session_computation_.trace_requests()) { + if (trace_request.operand().handle() <= version_) { + HloInstruction* operand = visited[trace_request.operand().handle()]; + // Trace instructions cannot be the root of a computation. + HloInstruction* trace_instruction = hlo_builder_.AddInstruction( + HloInstruction::CreateTrace(trace_request.tag(), operand)); + operand->set_tracing(trace_instruction); + } + } + + // Send instructions do not have users, so they are not reachable from the + // root instruction. 
Therefore, explicitly visit all Send requests (and their + // operand chains) and add to the builder. + for (const auto& send_request : session_computation_.send_requests()) { + Visit(send_request.operand(), &visited); + HloInstruction* operand = visited[send_request.operand().handle()]; + HloInstruction* send_instruction = + hlo_builder_.AddInstruction(HloInstruction::CreateSend(operand)); + send_instruction->set_channel_id(send_request.channel_handle().handle()); + } + + return hlo_builder_.Build(hlo_root); +} + +HloComputation* ComputationLowerer::ResolveComputation( + const ComputationHandle& handle, + VersionedComputationHandle::Version version) { + const VersionedComputationHandle checked_handle = {handle, version}; + return hlo_resolver_(checked_handle); +} + +HloInstruction* ComputationLowerer::Visit( + const ComputationDataHandle& handle, + std::map* visited) { + if (visited->count(handle.handle()) != 0) { + return (*visited)[handle.handle()]; + } + + const OperationRequest& request = + session_computation_.requests().at(handle.handle()); + HloInstruction* hlo_instruction; + switch (request.request().op_case()) { + case OpRequest::kRngRequest: { + const RngRequest& rng_request = request.request().rng_request(); + std::vector parameters; + for (const ComputationDataHandle& param : rng_request.parameter()) { + parameters.push_back(Visit(param, visited)); + } + hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreateRng( + request.output_shape(), rng_request.distribution(), parameters)); + break; + } + + case OpRequest::kConstantRequest: { + const ConstantRequest& constant_request = + request.request().constant_request(); + hlo_instruction = + hlo_builder_.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CloneToUnique(constant_request.literal()))); + break; + } + + case OpRequest::kGetTupleElementRequest: { + const GetTupleElementRequest& get_tuple_element_request = + request.request().get_tuple_element_request(); + HloInstruction* operand = + Visit(get_tuple_element_request.operand(), visited); + hlo_instruction = + hlo_builder_.AddInstruction(HloInstruction::CreateGetTupleElement( + request.output_shape(), operand, + get_tuple_element_request.index())); + break; + } + + case OpRequest::kSliceRequest: { + const SliceRequest& slice_request = request.request().slice_request(); + HloInstruction* operand = Visit(slice_request.operand(), visited); + hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreateSlice( + request.output_shape(), operand, + AsInt64Slice(slice_request.start_indices()), + AsInt64Slice(slice_request.limit_indices()))); + break; + } + + case OpRequest::kDynamicSliceRequest: { + const DynamicSliceRequest& dynamic_slice_request = + request.request().dynamic_slice_request(); + HloInstruction* operand = Visit(dynamic_slice_request.operand(), visited); + HloInstruction* start_indices = + Visit(dynamic_slice_request.start_indices(), visited); + + hlo_instruction = + hlo_builder_.AddInstruction(HloInstruction::CreateDynamicSlice( + request.output_shape(), operand, start_indices, + AsInt64Slice(dynamic_slice_request.slice_sizes()))); + break; + } + + case OpRequest::kDynamicUpdateSliceRequest: { + const DynamicUpdateSliceRequest& dynamic_update_slice_request = + request.request().dynamic_update_slice_request(); + HloInstruction* operand = + Visit(dynamic_update_slice_request.operand(), visited); + HloInstruction* update = + Visit(dynamic_update_slice_request.update(), visited); + HloInstruction* start_indices = + 
Visit(dynamic_update_slice_request.start_indices(), visited); + hlo_instruction = + hlo_builder_.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + request.output_shape(), operand, update, start_indices)); + break; + } + + case OpRequest::kConcatenateRequest: { + const ConcatenateRequest& concatenate_request = + request.request().concatenate_request(); + std::vector operands; + for (const ComputationDataHandle& handle : + concatenate_request.operands()) { + HloInstruction* operand = Visit(handle, visited); + operands.push_back(operand); + } + hlo_instruction = hlo_builder_.AddInstruction( + HloInstruction::CreateConcatenate(request.output_shape(), operands, + concatenate_request.dimension())); + break; + } + + case OpRequest::kConvolveRequest: { + const ConvolveRequest& convolve_request = + request.request().convolve_request(); + HloInstruction* lhs = Visit(convolve_request.lhs(), visited); + HloInstruction* rhs = Visit(convolve_request.rhs(), visited); + hlo_instruction = + hlo_builder_.AddInstruction(HloInstruction::CreateConvolve( + request.output_shape(), lhs, rhs, convolve_request.window(), + convolve_request.dimension_numbers())); + break; + } + + case OpRequest::kCrossReplicaSumRequest: { + const CrossReplicaSumRequest& cross_replica_sum_request = + request.request().cross_replica_sum_request(); + HloInstruction* operand = + Visit(cross_replica_sum_request.operand(), visited); + hlo_instruction = + hlo_builder_.AddInstruction(HloInstruction::CreateCrossReplicaSum( + request.output_shape(), operand)); + break; + } + + case OpRequest::kInfeedRequest: { + hlo_instruction = hlo_builder_.AddInstruction( + HloInstruction::CreateInfeed(request.output_shape())); + break; + } + + case OpRequest::kMapRequest: { + const MapRequest& map_request = request.request().map_request(); + std::vector operands; + for (const ComputationDataHandle& handle : map_request.operands()) { + HloInstruction* operand = Visit(handle, visited); + operands.push_back(operand); + } + CHECK_EQ(1, request.embedded_computation_versions_size()); + VersionedComputationHandle::Version map_version = + request.embedded_computation_versions(0); + HloComputation* map_computation = + ResolveComputation(map_request.to_apply(), map_version); + hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreateMap( + request.output_shape(), operands, map_computation)); + break; + } + + case OpRequest::kReduceRequest: { + const ReduceRequest& reduce_request = request.request().reduce_request(); + HloInstruction* operand = Visit(reduce_request.operand(), visited); + HloInstruction* init_value = Visit(reduce_request.init_value(), visited); + CHECK_EQ(1, request.embedded_computation_versions_size()); + VersionedComputationHandle::Version reduce_version = + request.embedded_computation_versions(0); + HloComputation* reduce_computation = + ResolveComputation(reduce_request.to_apply(), reduce_version); + hlo_instruction = + hlo_builder_.AddInstruction(HloInstruction::CreateReduce( + request.output_shape(), operand, init_value, + AsInt64Slice(reduce_request.dimensions()), reduce_computation)); + break; + } + + case OpRequest::kReduceWindowRequest: { + const ReduceWindowRequest& reduce_window_request = + request.request().reduce_window_request(); + HloInstruction* operand = Visit(reduce_window_request.operand(), visited); + HloInstruction* init_value = + Visit(reduce_window_request.init_value(), visited); + CHECK_EQ(1, request.embedded_computation_versions_size()); + VersionedComputationHandle::Version reduce_window_version = + 
request.embedded_computation_versions(0); + HloComputation* reduce_window_computation = ResolveComputation( + reduce_window_request.to_apply(), reduce_window_version); + hlo_instruction = + hlo_builder_.AddInstruction(HloInstruction::CreateReduceWindow( + request.output_shape(), operand, init_value, + reduce_window_request.window(), reduce_window_computation)); + break; + } + + case OpRequest::kSelectAndScatterRequest: { + const SelectAndScatterRequest& select_and_scatter_request = + request.request().select_and_scatter_request(); + HloInstruction* operand = + Visit(select_and_scatter_request.operand(), visited); + HloInstruction* source = + Visit(select_and_scatter_request.source(), visited); + HloInstruction* init_value = + Visit(select_and_scatter_request.init_value(), visited); + CHECK_EQ(2, request.embedded_computation_versions_size()); + VersionedComputationHandle::Version select_version = + request.embedded_computation_versions(0); + VersionedComputationHandle::Version scatter_version = + request.embedded_computation_versions(1); + HloComputation* select_computation = ResolveComputation( + select_and_scatter_request.select(), select_version); + HloComputation* scatter_computation = ResolveComputation( + select_and_scatter_request.scatter(), scatter_version); + hlo_instruction = + hlo_builder_.AddInstruction(HloInstruction::CreateSelectAndScatter( + request.output_shape(), operand, select_computation, + select_and_scatter_request.window(), source, init_value, + scatter_computation)); + break; + } + + case OpRequest::kBroadcastRequest: { + const BroadcastRequest& broadcast_request = + request.request().broadcast_request(); + HloInstruction* operand = Visit(broadcast_request.operand(), visited); + std::vector broadcast_dimensions; + // The client-level broadcast instruction just appends dimensions on the + // left (adds lowest numbered dimensions). The HLO broadcast op is more + // flexible and can add new dimensions anywhere. The broadcast_dimensions + // maps operand dimensions to dimensions in the broadcast output, so + // to append dimensions on the left the broadcast_dimensions should just + // be the n highest dimension numbers of the output shape where n is + // the number of input dimensions. 
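// ---------------------------------------------------------------------------
// Aside (not part of this patch): the broadcast_dimensions computed in the
// loop below, pulled out as a standalone helper with a worked example. For an
// operand of rank n broadcast into an output of rank m (new dimensions
// appended on the left), operand dimension i maps to output dimension
// i + (m - n), i.e. the n highest output dimension numbers.
#include <cstdint>
#include <vector>

namespace sketch {

std::vector<int64_t> LeftPaddingBroadcastDimensions(int64_t operand_rank,
                                                    int64_t output_rank) {
  std::vector<int64_t> broadcast_dimensions;
  for (int64_t i = 0; i < operand_rank; ++i) {
    broadcast_dimensions.push_back(i + output_rank - operand_rank);
  }
  return broadcast_dimensions;
}

}  // namespace sketch

// Example: broadcasting a rank-2 operand of shape [3, 4] into an output of
// shape [5, 6, 3, 4] yields broadcast_dimensions {2, 3}: operand dimension 0
// maps to output dimension 2 and operand dimension 1 to output dimension 3.
// ---------------------------------------------------------------------------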
+ for (int i = 0; i < ShapeUtil::Rank(operand->shape()); ++i) { + broadcast_dimensions.push_back(i + + ShapeUtil::Rank(request.output_shape()) - + ShapeUtil::Rank(operand->shape())); + } + hlo_instruction = + hlo_builder_.AddInstruction(HloInstruction::CreateBroadcast( + request.output_shape(), operand, broadcast_dimensions)); + break; + } + + case OpRequest::kReshapeRequest: { + const ReshapeRequest& reshape_request = + request.request().reshape_request(); + HloInstruction* operand = Visit(reshape_request.operand(), visited); + hlo_instruction = + hlo_builder_.AddInstruction(HloInstruction::CreateReshape( + request.output_shape(), + hlo_builder_.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::PermuteDimensions( + InversePermutation( + AsInt64Slice(reshape_request.dimensions())), + operand->shape()), + operand, AsInt64Slice(reshape_request.dimensions()))))); + break; + } + + case OpRequest::kReverseRequest: { + const ReverseRequest& reverse_request = + request.request().reverse_request(); + HloInstruction* operand = Visit(reverse_request.operand(), visited); + hlo_instruction = + hlo_builder_.AddInstruction(HloInstruction::CreateReverse( + request.output_shape(), operand, + AsInt64Slice(reverse_request.dimensions()))); + break; + } + + case OpRequest::kPadRequest: { + const PadRequest& pad_request = request.request().pad_request(); + HloInstruction* operand = Visit(pad_request.operand(), visited); + HloInstruction* padding_value = + Visit(pad_request.padding_value(), visited); + hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreatePad( + request.output_shape(), operand, padding_value, + pad_request.padding_config())); + break; + } + + case OpRequest::kRecvRequest: { + const RecvRequest& recv_request = request.request().recv_request(); + hlo_instruction = hlo_builder_.AddInstruction( + HloInstruction::CreateRecv(request.output_shape())); + hlo_instruction->set_channel_id(recv_request.channel_handle().handle()); + break; + } + + case OpRequest::kParameterRequest: { + const ParameterRequest& parameter_request = + request.request().parameter_request(); + hlo_instruction = + hlo_builder_.AddInstruction(HloInstruction::CreateParameter( + parameter_request.parameter(), request.output_shape(), + parameter_request.name())); + break; + } + + case OpRequest::kConvertRequest: { + const ConvertRequest& convert_request = + request.request().convert_request(); + HloInstruction* operand = Visit(convert_request.operand(), visited); + hlo_instruction = hlo_builder_.AddInstruction( + HloInstruction::CreateConvert(request.output_shape(), operand)); + break; + } + + case OpRequest::kWhileRequest: { + const WhileRequest& while_request = request.request().while_request(); + CHECK_EQ(2, request.embedded_computation_versions_size()); + VersionedComputationHandle::Version condition_version = + request.embedded_computation_versions(0); + HloComputation* condition = + ResolveComputation(while_request.condition(), condition_version); + VersionedComputationHandle::Version body_version = + request.embedded_computation_versions(1); + HloComputation* body = + ResolveComputation(while_request.body(), body_version); + HloInstruction* init = Visit(while_request.init(), visited); + hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreateWhile( + request.output_shape(), condition, body, init)); + break; + } + + case OpRequest::kTernaryOpRequest: { + const TernaryOpRequest& ternary_op_request = + request.request().ternary_op_request(); + HloInstruction* lhs = 
Visit(ternary_op_request.lhs(), visited); + HloInstruction* rhs = Visit(ternary_op_request.rhs(), visited); + HloInstruction* ehs = Visit(ternary_op_request.ehs(), visited); + auto hlo_opcode = TernaryOperationToHloOpcode(ternary_op_request.triop()); + hlo_instruction = + hlo_builder_.AddInstruction(HloInstruction::CreateTernary( + request.output_shape(), hlo_opcode, lhs, rhs, ehs)); + break; + } + + case OpRequest::kVariadicOpRequest: { + const VariadicOpRequest& variadic_op_request = + request.request().variadic_op_request(); + std::vector operands; + for (const ComputationDataHandle& handle : + variadic_op_request.operands()) { + HloInstruction* operand = Visit(handle, visited); + operands.push_back(operand); + } + auto hlo_opcode = + VariadicOperationToHloOpcode(variadic_op_request.varop()); + hlo_instruction = + hlo_builder_.AddInstruction(HloInstruction::CreateVariadic( + request.output_shape(), hlo_opcode, operands)); + break; + } + + case OpRequest::kCallRequest: { + const CallRequest& call_request = request.request().call_request(); + std::vector operands; + for (const ComputationDataHandle& handle : call_request.operands()) { + operands.push_back(Visit(handle, visited)); + } + CHECK_EQ(1, request.embedded_computation_versions_size()); + VersionedComputationHandle::Version call_version = + request.embedded_computation_versions(0); + HloComputation* call_computation = + ResolveComputation(call_request.to_apply(), call_version); + hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreateCall( + request.output_shape(), operands, call_computation)); + break; + } + + case OpRequest::kCustomCallRequest: { + const CustomCallRequest& cc_request = + request.request().custom_call_request(); + std::vector operands; + for (const ComputationDataHandle& operand : cc_request.operands()) { + operands.push_back(Visit(operand, visited)); + } + hlo_instruction = + hlo_builder_.AddInstruction(HloInstruction::CreateCustomCall( + cc_request.shape(), operands, cc_request.call_target_name())); + break; + } + + case OpRequest::kUnaryOpRequest: { + const UnaryOpRequest& unary_op_request = + request.request().unary_op_request(); + HloInstruction* operand = Visit(unary_op_request.operand(), visited); + auto hlo_opcode = UnaryOperationToHloOpcode(unary_op_request.unop()); + hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreateUnary( + request.output_shape(), hlo_opcode, operand)); + break; + } + + case OpRequest::kBinaryOpRequest: { + const BinaryOpRequest& binary_op_request = + request.request().binary_op_request(); + HloInstruction* lhs = Visit(binary_op_request.lhs(), visited); + HloInstruction* rhs = Visit(binary_op_request.rhs(), visited); + auto hlo_opcode = BinaryOperationToHloOpcode(binary_op_request.binop()); + if (binary_op_request.broadcast_dimensions_size() > 0) { + // Emit a broadcast instruction to perform the "broadcast in dimension" + // operation. + CHECK_NE(ShapeUtil::Rank(lhs->shape()), ShapeUtil::Rank(rhs->shape())); + HloInstruction* operand_to_broadcast = + ShapeUtil::Rank(lhs->shape()) < ShapeUtil::Rank(rhs->shape()) ? 
lhs + : rhs; + Shape broadcast_shape = ShapeUtil::MakeShape( + operand_to_broadcast->shape().element_type(), + AsInt64Slice(request.output_shape().dimensions())); + + CHECK_EQ(ShapeUtil::Rank(operand_to_broadcast->shape()), + binary_op_request.broadcast_dimensions().size()); + // The broadcast semantics of a client-level binary op broadcast is + // identical to the HLO broadcast semantics so the broadcast_dimensions + // field can just be passed to the instruction builder. + HloInstruction* broadcasted_operand = + hlo_builder_.AddInstruction(HloInstruction::CreateBroadcast( + broadcast_shape, operand_to_broadcast, + AsInt64Slice(binary_op_request.broadcast_dimensions()))); + + lhs = (lhs == operand_to_broadcast) ? broadcasted_operand : lhs; + rhs = (rhs == operand_to_broadcast) ? broadcasted_operand : rhs; + } + hlo_instruction = + hlo_builder_.AddInstruction(HloInstruction::CreateBinary( + request.output_shape(), hlo_opcode, lhs, rhs)); + break; + } + + case OpRequest::OP_NOT_SET: + LOG(FATAL) << "OperationRequest doesn't contain a request"; + + default: + LOG(FATAL) << "Unexpected request type: " << request.request().op_case(); + } + (*visited)[handle.handle()] = hlo_instruction; + return hlo_instruction; +} + +} // namespace + +StatusOr> UserComputation::BuildHloComputation( + VersionedComputationHandle::Version version, + HloComputationResolver hlo_resolver, bool include_unused_parameters) const { + tensorflow::mutex_lock lock(mutex_); + + VLOG(2) << "Building HloComputation from UserComputation " << name_ + << " at version " << version << ". Operation requests:\n" + << session_computation_.ShortDebugString(); + + std::unique_ptr hlo_computation = ComputationLowerer::Lower( + tensorflow::strings::StrCat(name(), ".v", version), session_computation_, + version, std::move(hlo_resolver), include_unused_parameters); + + VLOG(2) << "HloComputation:\n" << hlo_computation->ToString(); + return std::move(hlo_computation); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/user_computation.h b/tensorflow/compiler/xla/service/user_computation.h new file mode 100644 index 0000000000..06824b01c7 --- /dev/null +++ b/tensorflow/compiler/xla/service/user_computation.h @@ -0,0 +1,336 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_USER_COMPUTATION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_USER_COMPUTATION_H_ + +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/session.pb.h" +#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// A UserComputation is the built-up computation that users create via the +// XLA Service interface. +// +// The XLA service adds instructions to a user computation via this +// interface. The state of the computation is stored as a SessionComputation +// proto which holds a record of all operation-building requests received by the +// XLA service. +// +// UserComputations are lowered to HloComputations which are passed to the high +// level compiler interface. +class UserComputation { + public: + // Factory used when restoring a computation from serialized session + // computation (computation snapshot) data. Remaps any references to + // computation handle via the old_to_new mapping. + // + // An error will occur if the old_to_new mapping cannot resolve a reference to + // a computation that is present in session_computation. + static StatusOr> MakeWithRemapping( + const SessionComputation& session_computation, + const ComputationHandle& handle, + const std::map& old_to_new); + + // Creates an empty computation with the given name and computation handle. + explicit UserComputation(const string& name, const ComputationHandle& handle); + + // Enqueues a parameter-retrieving instruction onto this user computation. + // Returns an error status if the parameter number is already registered with + // different values. + StatusOr AddParameterInstruction( + const ParameterRequest& parameter_request); + + // Enqueues a pad instruction onto this user computation. + StatusOr AddPadInstruction( + const PadRequest& parameter_request); + + // Enqueues a tracing instruction onto this user computation. + // Returns an error status if the operand cannot be resolved. + Status AddTraceInstruction(const TraceRequest& trace_request); + + // Enqueues a random number generation instruction onto this user computation. + StatusOr AddRngInstruction( + const RngRequest& rng_request); + + // Enqueues a unary instruction onto this user computation. + // Returns an error status if the operand index is out of bounds. + StatusOr AddUnaryInstruction( + const UnaryOpRequest& unary_request); + + // Enqueues a binary instruction onto this user computation. + // Returns an error status if the operand indices are out of bounds. + StatusOr AddBinaryInstruction( + const BinaryOpRequest& binary_request); + + // Enqueues a ternary instruction onto this user computation. + // Returns an error status if the operand indices are out of bounds. + StatusOr AddTernaryInstruction( + const TernaryOpRequest& request); + + // Enqueues a variadic instruction onto this user computation. + // Returns an error status if the operand indices are out of bounds. 
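// ---------------------------------------------------------------------------
// Aside (not part of this patch): MakeWithRemapping above relies on an
// old-to-new table of computation handles; every computation reference
// recorded in the snapshot must resolve through that table or the restore
// fails with NotFound. A standalone model of that lookup, using plain int64_t
// handles instead of the ComputationHandle proto and a bool instead of a
// Status:
#include <cstdint>
#include <map>
#include <vector>

namespace sketch {

// Remaps every handle in-place; returns false if any reference is missing
// from the mapping (the real code returns a NotFound status instead).
bool RemapHandles(const std::map<int64_t, int64_t>& old_to_new,
                  std::vector<int64_t>* handles) {
  for (int64_t& handle : *handles) {
    auto it = old_to_new.find(handle);
    if (it == old_to_new.end()) {
      return false;  // unresolved reference to an old computation
    }
    handle = it->second;
  }
  return true;
}

}  // namespace sketch
// ---------------------------------------------------------------------------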
+ StatusOr AddVariadicInstruction( + const VariadicOpRequest& variadic_request); + + // Enqueues a constant instruction onto this user computation. + StatusOr AddConstantInstruction( + const ConstantRequest& constant_request); + + // Enqueues a get tuple element instruction onto this user computation. + StatusOr AddGetTupleElementInstruction( + const GetTupleElementRequest& get_tuple_element_request); + + // Enqueues a map instruction onto this user computation. + StatusOr AddMapInstruction( + const MapRequest& map_request, + const UserComputation& to_apply_computation); + + // Enqueues a convolution instruction onto this user computation. + StatusOr AddConvolveInstruction( + const ConvolveRequest& convolve_request); + + // Enqueues a cross replica sum instruction onto this user computation. + StatusOr AddCrossReplicaSumInstruction( + const CrossReplicaSumRequest& cross_replica_sum_request); + + // Enqueues an infeed instruction onto this user computation. + StatusOr AddInfeedInstruction( + const InfeedRequest& infeed_request); + + // Enqueues a call instruction onto this user computation. + StatusOr AddCallInstruction( + const CallRequest& call_request, + const UserComputation& to_apply_computation); + + // Enqueues a custom call instruction onto this user computation. + StatusOr AddCustomCallInstruction( + const CustomCallRequest& custom_call_request); + + // Enqueues a broadcast instruction onto this user computation. + StatusOr AddBroadcastInstruction( + const BroadcastRequest& broadcast_request); + + // Enqueues a reshape instruction onto this user computation. + StatusOr AddReshapeInstruction( + const ReshapeRequest& reshape_request); + + // Enqueues a slice instruction onto this user computation. + StatusOr AddSliceInstruction( + const SliceRequest& slice_request); + + // Enqueues a dynamic slice instruction onto this user computation. + StatusOr AddDynamicSliceInstruction( + const DynamicSliceRequest& dynamic_slice_request); + + // Enqueues a dynamic update slice instruction onto this user computation. + StatusOr AddDynamicUpdateSliceInstruction( + const DynamicUpdateSliceRequest& dynamic_update_slice_request); + + // Enqueues a concatenate instruction onto this user computation. + StatusOr AddConcatenateInstruction( + const ConcatenateRequest& slice_request); + + // Enqueues a convert instruction onto this user computation. + StatusOr AddConvertInstruction( + const ConvertRequest& convert_request); + + // Enqueues a reduce instruction onto this user computation. + StatusOr AddReduceInstruction( + const ReduceRequest& reduce_request, + const UserComputation& reduction_computation); + + // Enqueues a windowed reduce instruction onto this user computation. + StatusOr AddReduceWindowInstruction( + const ReduceWindowRequest& reduce_window_request, + const UserComputation& reduction_computation); + + // Enqueues a select-and-scatter instruction onto this user + // computation. + StatusOr AddSelectAndScatterInstruction( + const SelectAndScatterRequest& scatter_to_selected_window_element_request, + const UserComputation& select_computation, + const UserComputation& scatter_computation); + + // Enqueues a reverse instruction onto this user computation. + StatusOr AddReverseInstruction( + const ReverseRequest& reverse_request); + + // Enqueues a while instruction onto this user computation. 
+ StatusOr AddWhileInstruction( + const WhileRequest& while_request, + const UserComputation& condition_computation, + const UserComputation& body_computation); + + // Enqueues a Send instruction onto this user computation. + Status AddSendInstruction(const SendRequest& send_request); + + // Enqueues a Recv instruction onto this user computation. + StatusOr AddRecvInstruction( + const RecvRequest& recv_request); + + // Returns the user-provided name of this user computation, which is provided + // via the XLA computation-building API. + const string& name() const { return name_; } + + // Subsequent executions of this computation will compute the value + // represented by handle, rather than the last expression enqueued + // on the computation. + Status SetReturnValue(const ComputationDataHandle& handle); + + // Return a versioned handle for this computation. + VersionedComputationHandle GetVersionedHandle() const; + + // Return a versioned handle for this computation with a version equal to the + // point at which given operation was added to the computation. + VersionedComputationHandle GetVersionedHandleAtOperation( + const ComputationDataHandle& operation) const; + + // Return a version value representing the current state of the + // computation. + VersionedComputationHandle::Version version() const; + + // Computes and returns the program shape for the user computation -- gathers + // parameters and result type into a single proto. A shared_ptr is used + // because the returned pointer refers to an internally cached value which may + // be discarded by the UserComputation object. This avoid unnecessary copies. + // + // If the parameter space is not dense (i.e. there are holes in the parameter + // numbers provided) then an error status is returned. + StatusOr> ComputeProgramShape( + VersionedComputationHandle::Version version) const; + + // Returns true if the given data handle does not depend on any + // parameters. That is, the value can be computed at compile time. + StatusOr IsConstant(const ComputationDataHandle& handle); + + // Returns the output shape of the operation indicated by the given handle. + StatusOr GetShape(const ComputationDataHandle& handle); + + // Builds a HLO computation from the UserComputation. The parameter "resolver" + // is a function which returns a pointer to the HloComputation corresponding + // to the given ComputationHandle at the given version. The resolver is used + // for operations, such as map, which call other computations and need a + // pointer to the called HloComputation to construct the respective HLO + // instructions. If include_unused_computation is true, then all parameter + // instructions are lowered into HloInstructions even if the parameter is + // unused (the root of the computation is unreachable from the parameter). + using HloComputationResolver = + std::function; + StatusOr> BuildHloComputation( + VersionedComputationHandle::Version version, + HloComputationResolver hlo_resolver, + bool include_unused_parameters = true) const; + + // Return a vector containing the embedded computations used by this + // UserComputation. Only embedded computations which are called directly by + // this UserComputation are included. That is, the transitive closure of + // embedded computations is not included. + std::vector GetEmbeddedComputations( + VersionedComputationHandle::Version version) const; + + // Returns the number of OperationRequest objects in this UserComputation. 
+ // The 'version' of a computation is identical to the number of + // OperationRequests in the UserComputation. + int64 request_count(VersionedComputationHandle::Version version) const { + return version; + } + + // Returns a copy of the internal session state for this computation -- this + // is useful for serializing the guts of a user computation, though references + // to other handles (e.g. referred-to computations) must be handled with care + // in the serialization / de-serialization process. + SessionComputation CloneSessionComputation( + VersionedComputationHandle::Version version) const; + + private: + // Warning: dangerous mutating operation that doesn't respect versioning. + // This is only used at initialization time when constructing from a + // SessionComputation a la MakeWithRemapping. + // + // Remaps references to old computations (with handle values in the keys of + // old_to_new) to the computation handle given in the values. This is useful + // when loading computations from snapshots, to finish initialization, before + // the user computation is released into the wild. + Status RemapEmbeddedComputations( + const std::map& old_to_new) + EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Returns the OperationRequestion corresponding to the root (result) of the + // computation. + const OperationRequest& GetRoot(VersionedComputationHandle::Version version) + const EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Returns the OperationRequest corresponding to the given handle value. + StatusOr LookupRequest( + const ComputationDataHandle& handle) const + EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Creates a new ComputationDataHandle with the next available handle value. + ComputationDataHandle CreateComputationDataHandle() + EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Checks whether the parameter numbers of the parameter operations are + // contiguous starting from zero. Returns appropriate error status if not. + Status CheckParametersAreContiguous( + VersionedComputationHandle::Version version) const + EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Name of the computation. + string name_; + + mutable tensorflow::mutex mutex_; + + // State of the computation as a record of all operation-building requests. + SessionComputation session_computation_ GUARDED_BY(mutex_); + + // Mapping from parameter number to operation request containing the + // respective ParameterRequest. + std::map parameters_ GUARDED_BY(mutex_); + + // The next ComputationDataHandle value to assign. Handle values are assigned + // sequentially. + int64 next_handle_value_ GUARDED_BY(mutex_); + + // If handle_to_return_.has_handle() then an Execution of this Computation + // will compute the value represented by handle_to_return_, otherwise it will + // compute the value of (next_handle_value_ - 1). + ComputationDataHandle handle_to_return_ GUARDED_BY(mutex_); + + // Memoized ProgramShape and its version. A shared_ptr is used because + // references to this object are returned by ComputeProgramShape. 
+ mutable int64 program_shape_version_ GUARDED_BY(mutex_) = 0; + mutable std::shared_ptr program_shape_ GUARDED_BY(mutex_); + + TF_DISALLOW_COPY_AND_ASSIGN(UserComputation); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_USER_COMPUTATION_H_ diff --git a/tensorflow/compiler/xla/service/versioned_computation_handle.h b/tensorflow/compiler/xla/service/versioned_computation_handle.h new file mode 100644 index 0000000000..03bee3d4a5 --- /dev/null +++ b/tensorflow/compiler/xla/service/versioned_computation_handle.h @@ -0,0 +1,48 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_VERSIONED_COMPUTATION_HANDLE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_VERSIONED_COMPUTATION_HANDLE_H_ + +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// A data structure encapsulating a ComputationHandle and version value of that +// computation. This object is used to unambiguously refer to a particular +// computation in the service. +struct VersionedComputationHandle { + // A version value unambiguously specifying the state of the computation at a + // particular point in time as it is being built. This value is the + // ComputationDataHandle of the current root instruction. + using Version = int64; + + ComputationHandle handle; + Version version; + bool operator==(const VersionedComputationHandle& other) const { + return (handle.handle() == other.handle.handle()) && + (version == other.version); + } + bool operator<(const VersionedComputationHandle& other) const { + return ((handle.handle() < other.handle.handle()) || + ((handle.handle() == other.handle.handle()) && + (version < other.version))); + } +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_VERSIONED_COMPUTATION_HANDLE_H_ diff --git a/tensorflow/compiler/xla/service_interface.h b/tensorflow/compiler/xla/service_interface.h new file mode 100644 index 0000000000..fc107480f7 --- /dev/null +++ b/tensorflow/compiler/xla/service_interface.h @@ -0,0 +1,117 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INTERFACE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_INTERFACE_H_ + +#include "tensorflow/compiler/xla/xla.pb.h" +#include "tensorflow/core/lib/core/status.h" + +namespace xla { + +// Defines the interface for an XLA service. +class ServiceInterface { + public: + ServiceInterface() {} + virtual ~ServiceInterface() = default; + + // TODO(b/31824348): Convert to use StatusOr. + virtual tensorflow::Status TransferToClient( + const TransferToClientRequest* arg, TransferToClientResponse* result) = 0; + + virtual tensorflow::Status TransferToClientInProcess( + const TransferToClientInProcessRequest* arg, + TransferToClientInProcessResponse* result) = 0; + + virtual tensorflow::Status TransferToServer( + const TransferToServerRequest* arg, TransferToServerResponse* result) = 0; + + virtual tensorflow::Status TransferToInfeed( + const TransferToInfeedRequest* arg, TransferToInfeedResponse* result) = 0; + + virtual tensorflow::Status ResetDevice(const ResetDeviceRequest* arg, + ResetDeviceResponse* result) = 0; + + virtual tensorflow::Status TransferToServerInProcess( + const TransferToServerInProcessRequest* arg, + TransferToServerInProcessResponse* result) = 0; + + virtual tensorflow::Status LoadComputationSnapshot( + const LoadComputationSnapshotRequest* request, + LoadComputationSnapshotResponse* result) = 0; + + virtual tensorflow::Status Execute(const ExecuteRequest* arg, + ExecuteResponse* result) = 0; + + virtual tensorflow::Status ExecuteParallel( + const ExecuteParallelRequest* arg, ExecuteParallelResponse* result) = 0; + + virtual tensorflow::Status ExecuteAsync(const ExecuteAsyncRequest* arg, + ExecuteAsyncResponse* result) = 0; + + virtual tensorflow::Status WaitForExecution( + const WaitForExecutionRequest* arg, WaitForExecutionResponse* result) = 0; + + virtual tensorflow::Status DeconstructTuple( + const DeconstructTupleRequest* arg, DeconstructTupleResponse* result) = 0; + + virtual tensorflow::Status GetComputationStats( + const ComputationStatsRequest* arg, ComputationStatsResponse* result) = 0; + + virtual tensorflow::Status GetComputationShape( + const GetComputationShapeRequest* arg, + GetComputationShapeResponse* result) = 0; + + virtual tensorflow::Status GetShape(const GetShapeRequest* arg, + GetShapeResponse* result) = 0; + + virtual tensorflow::Status CreateChannelHandle( + const CreateChannelHandleRequest* arg, + CreateChannelHandleResponse* result) = 0; + + virtual tensorflow::Status GetDeviceHandles( + const GetDeviceHandlesRequest* arg, GetDeviceHandlesResponse* result) = 0; + + // Methods used by ComputationBuilder. + virtual tensorflow::Status Computation(const ComputationRequest* arg, + ComputationResponse* result) = 0; + + virtual tensorflow::Status Op(const OpRequest* arg, OpResponse* result) = 0; + + virtual tensorflow::Status GetLocalShape(const GetLocalShapeRequest* arg, + GetLocalShapeResponse* result) = 0; + + virtual tensorflow::Status SetReturnValue( + const SetReturnValueRequest* arg, SetReturnValueResponse* results) = 0; + + virtual tensorflow::Status IsConstant(const IsConstantRequest* arg, + IsConstantResponse* result) = 0; + + virtual tensorflow::Status ComputeConstant( + const ComputeConstantRequest* arg, ComputeConstantResponse* result) = 0; + + // Methods used by Computation. 
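+ // As a rough illustration of the request/response convention shared by all
+ // methods in this interface (the service pointer below is hypothetical):
+ //
+ //   SnapshotComputationRequest request;   // callers fill in the handle of
+ //                                         // the computation to snapshot
+ //   SnapshotComputationResponse response;
+ //   tensorflow::Status s = service->SnapshotComputation(&request, &response);
+ //   if (!s.ok()) { /* surface the error */ }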
+ virtual tensorflow::Status SnapshotComputation( + const SnapshotComputationRequest* ag, + SnapshotComputationResponse* result) = 0; + + // Methods used by GlobalData. + virtual tensorflow::Status Unregister(const UnregisterRequest* arg, + UnregisterResponse* result) = 0; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_INTERFACE_H_ diff --git a/tensorflow/compiler/xla/shape_layout.cc b/tensorflow/compiler/xla/shape_layout.cc new file mode 100644 index 0000000000..5bf9842a6c --- /dev/null +++ b/tensorflow/compiler/xla/shape_layout.cc @@ -0,0 +1,78 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/shape_layout.h" + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +tensorflow::Status ShapeLayout::CopyLayoutFromShape(const Shape& other_shape) { + if (!ShapeUtil::Compatible(other_shape, shape_)) { + return InvalidArgument("Shape %s is not compatible with shape %s", + ShapeUtil::HumanString(other_shape).c_str(), + ShapeUtil::HumanString(shape()).c_str()); + } + shape_ = other_shape; + return tensorflow::Status::OK(); +} + +tensorflow::Status ShapeLayout::AssignLayoutToShape(Shape* other_shape) const { + if (!ShapeUtil::Compatible(*other_shape, shape_)) { + return InvalidArgument("Shape %s is not compatible with shape %s", + ShapeUtil::HumanString(*other_shape).c_str(), + ShapeUtil::HumanString(shape()).c_str()); + } + *other_shape = shape_; + return tensorflow::Status::OK(); +} + +void ShapeLayout::SetToDefaultLayout() { + LayoutUtil::SetToDefaultLayout(&shape_); +} + +bool ShapeLayout::MatchesLayoutInShape(const Shape& shape) const { + return ShapeUtil::Equal(shape, shape_); +} + +const Layout& ShapeLayout::layout() const { + CHECK(LayoutIsSet()); + CHECK(!ShapeUtil::IsTuple(shape_)); + return shape_.layout(); +} + +void ShapeLayout::Clear() { LayoutUtil::ClearLayout(&shape_); } + +bool ShapeLayout::LayoutIsSet() const { return LayoutUtil::HasLayout(shape_); } + +void ShapeLayout::ResetLayout(const Layout& layout) { + CHECK(!ShapeUtil::IsTuple(shape_)); + CHECK(!ShapeUtil::IsOpaque(shape_)); + *shape_.mutable_layout() = layout; + TF_CHECK_OK(ShapeUtil::ValidateShape(shape_)); +} + +bool ShapeLayout::operator==(const ShapeLayout& other) const { + return ShapeUtil::Equal(shape_, other.shape_); +} + +bool ShapeLayout::operator!=(const ShapeLayout& other) const { + return !ShapeUtil::Equal(shape_, other.shape_); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/shape_layout.h b/tensorflow/compiler/xla/shape_layout.h new file mode 100644 index 0000000000..92564660f2 --- /dev/null +++ b/tensorflow/compiler/xla/shape_layout.h @@ -0,0 +1,88 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SHAPE_LAYOUT_H_ +#define TENSORFLOW_COMPILER_XLA_SHAPE_LAYOUT_H_ + +#include + +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status.h" + +namespace xla { + +// A ShapeLayout object encapsulates the layout of a particular shape (including +// tuples). This differs from the Layout proto which describes the layout of a +// single array. ShapeLayout contains a Layout proto for each array in the shape +// (a tuple can have more than one array). For array shapes, this object +// trivially holds a single Layout. Logically, ShapeLayout holds a nonmutable +// shape with mutable layouts. +class ShapeLayout { + public: + // Constructs a ShapeLayout of the given shape. Layouts are copied from the + // shape parameter. + explicit ShapeLayout(const Shape& shape) : shape_(shape) {} + + // Assigns the layouts in this ShapeLayout to the Layout fields of the given + // shape. 'shape' and the shape of the ShapeLayout object must be compatible. + tensorflow::Status AssignLayoutToShape(Shape* shape) const; + + // Returns true if the Layouts in this ShapeLayout match the layouts in the + // given shape. Returns false otherwise. If the given shape is not compatible + // with the ShapeLayout's shape, then false is returned. + bool MatchesLayoutInShape(const Shape& shape) const; + + // Copies the layout from the given shape into this ShapeLayout. 'shape' must + // be compatible with the ShapeLayout's shape, and 'shape' must have a layout + // (LayoutUtil::HasLayout). + tensorflow::Status CopyLayoutFromShape(const Shape& shape); + + // Clears (Layout::Clear) all the Layouts stored in this object. + void Clear(); + + // Sets all Layouts stored in this object to the default layout. + void SetToDefaultLayout(); + + // Returns the shape (with layouts). + const Shape& shape() const { return shape_; } + + // Checks that a layout is set for the shape, and returns a reference to the + // layout directly on the shape. Shape must not be a tuple. + const Layout& layout() const; + + // Returns true if all layouts have been set for this ShapeLayout object. That + // is, every array has a layout. + bool LayoutIsSet() const; + + // Resets the layout on the shape to the provided layout. Shape must not be a + // tuple. + void ResetLayout(const Layout& layout); + + // Returns a string representation of this object. + string ToString() const { return ShapeUtil::HumanStringWithLayout(shape_); } + + // Tests for equality of both shape and layout (ShapeUtil::Equal). 
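+ //
+ // A minimal usage sketch (dimensions and layouts here are arbitrary):
+ //
+ //   Shape s = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {0, 1});
+ //   ShapeLayout shape_layout(s);                  // copies the {0, 1} layout
+ //   Shape t = ShapeUtil::MakeShape(F32, {2, 3});  // compatible shape
+ //   TF_CHECK_OK(shape_layout.AssignLayoutToShape(&t));
+ //   // t now carries the {0, 1} layout, so ShapeLayout(t) == shape_layout.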
+ bool operator==(const ShapeLayout& other) const; + bool operator!=(const ShapeLayout& other) const; + + private: + Shape shape_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SHAPE_LAYOUT_H_ diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h new file mode 100644 index 0000000000..6963a68d10 --- /dev/null +++ b/tensorflow/compiler/xla/shape_tree.h @@ -0,0 +1,260 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_ +#define TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// A ShapeTree is a recursive data structure which mirrors the structure of a +// XLA shape and holds a value of type T for each array in the shape. For +// array shapes, a ShapeTree trivially holds a single value of type T. For tuple +// shapes which can be an arbitrary tree with arrays at the leaves, a ShapeTree +// is an identically structured tree with data elements of type T at the leaves. +// +// Like the Shape data structure, this is a tree and tuple elements cannot be +// duplicated. That is, every distinct element position in the Shape has a +// unique T object. +template +class ShapeTree { + public: + explicit ShapeTree(const Shape& shape); + ShapeTree(const Shape& shape, const T& init_value); + ShapeTree(const ShapeTree& other); + ShapeTree& operator=(const ShapeTree& other); + + // Returns the data element associated with the array in the shape at the + // given index (see ShapeUtil::GetSubshape for how indexes are defined). + const T& element(const ShapeIndex& index) const; + T* mutable_element(const ShapeIndex& index); + + // Return the shape represented with this ShapeTree. + const Shape& shape() const { return *shape_; } + + // Returns true if the node at the given index is a leaf node (an array + // shape). + bool IsLeaf(const ShapeIndex& index) const { + return Lookup(index).elements_.empty(); + } + + // Recursively traverses the shape and calls the given function at each + // element. The function has the following arguments: + // + // index : the index of the element in the shape. See ShapeUtil::GetSubshape + // for definition of index. + // is_leaf : Whether this element is a leaf element in the shape. That is, + // whether this index corresponds to an array and not a (nested) + // tuple element. + // data : The data value at this elemnt. 
+ // + // If any call to the given function returns a non-OK status, then traversal + // is aborted and the status value is returned. + using VisitorFunction = std::function; + tensorflow::Status ForEachElement(VisitorFunction func) const; + + using MutableVisitorFunction = std::function; + tensorflow::Status ForEachMutableElement(MutableVisitorFunction func); + + private: + // Private default constructor for non-root nodes of the tree. + ShapeTree() = default; + + // Helpers for traversing the shape via ForEachElement. The helpers + // recursively traverse the subtree rooted at "index" (defined as in + // ShapeUtil::GetSubshape). + static tensorflow::Status ForEachHelperMutable(ShapeIndex* index, + ShapeTree* shape_tree, + MutableVisitorFunction func); + static tensorflow::Status ForEachHelper(ShapeIndex* index, + const ShapeTree& shape_tree, + VisitorFunction func); + + // Copy all the data elements (of type T) from "other" into "this". "this" + // must have the same tree structure as "other" prior to calling this method. + void CopyDataElements(const ShapeTree& other); + + // Recursive helper for constructing a subtree beneath "this" node. + void BuildTree(const Shape& shape); + + // Return the tree node at the given index. + ShapeTree& Lookup(const ShapeIndex& index); + const ShapeTree& Lookup(const ShapeIndex& index) const; + + // The data corresponding to the array at this node. + T data_; + + // The XLA shape mirrored in this ShapeTree. Only the root of the + // ShapeTree has this member set. + std::unique_ptr shape_; + + // The children of this node in the tree. + std::vector> elements_; +}; + +template +void ShapeTree::BuildTree(const Shape& shape) { + if (ShapeUtil::IsTuple(shape)) { + for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + elements_.emplace_back(new ShapeTree()); + elements_.back()->BuildTree(shape.tuple_shapes(i)); + } + } +} + +template +ShapeTree::ShapeTree(const Shape& shape) : shape_(MakeUnique(shape)) { + // The shape_ field is just used to hold the structure of the shape. It should + // not be relied upon to store layout information. 
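+ // For instance (shapes here are illustrative), given
+ //
+ //   Shape s = ShapeUtil::MakeTupleShape(
+ //       {ShapeUtil::MakeShape(F32, {8}),
+ //        ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {4})})});
+ //   ShapeTree<int> tree(s);
+ //
+ // BuildTree() below creates nodes for the indices {}, {0}, {1} and {1, 0};
+ // each node's value should be set via mutable_element(), or use the
+ // (shape, init_value) constructor to give every node an initial value.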
+ LayoutUtil::ClearLayout(shape_.get()); + BuildTree(*shape_); +} + +template +ShapeTree::ShapeTree(const Shape& shape, const T& init_value) + : shape_(MakeUnique(shape)) { + LayoutUtil::ClearLayout(shape_.get()); + BuildTree(*shape_); + TF_CHECK_OK(ForEachMutableElement( + [&init_value](const ShapeIndex& /*index*/, bool /*is_leaf*/, bool* data) { + *data = init_value; + return tensorflow::Status::OK(); + })); +} + +template +ShapeTree::ShapeTree(const ShapeTree& other) + : shape_(MakeUnique(other.shape())) { + LayoutUtil::ClearLayout(shape_.get()); + BuildTree(*shape_); + CopyDataElements(other); +} + +template +ShapeTree& ShapeTree::operator=(const ShapeTree& other) { + if (this == &other) { + return *this; + } + elements_.clear(); + shape_ = MakeUnique(other.shape()); + LayoutUtil::ClearLayout(shape_.get()); + + BuildTree(*shape_); + CopyDataElements(other); + return *this; +} + +template +void ShapeTree::CopyDataElements(const ShapeTree& other) { + CHECK(ShapeUtil::Compatible(shape(), other.shape())); + TF_CHECK_OK(ForEachMutableElement( + [&other](const ShapeIndex& index, bool /*is_leaf*/, T* data) { + *data = other.element(index); + return tensorflow::Status::OK(); + })); +} + +template +const T& ShapeTree::element(const ShapeIndex& index) const { + return Lookup(index).data_; +} + +template +T* ShapeTree::mutable_element(const ShapeIndex& index) { + return &Lookup(index).data_; +} + +template +ShapeTree& ShapeTree::Lookup(const ShapeIndex& index) { + ShapeTree* node = this; + for (auto& i : index) { + CHECK_GE(i, 0); + CHECK_LT(i, node->elements_.size()); + node = node->elements_[i].get(); + } + return *node; +} + +template +const ShapeTree& ShapeTree::Lookup(const ShapeIndex& index) const { + return const_cast*>(this)->Lookup(index); +} + +/* static */ +template +tensorflow::Status ShapeTree::ForEachHelperMutable( + ShapeIndex* index, ShapeTree* shape_tree, + ShapeTree::MutableVisitorFunction func) { + TF_RETURN_IF_ERROR( + func(*index, shape_tree->elements_.empty(), &shape_tree->data_)); + for (int i = 0; i < shape_tree->elements_.size(); ++i) { + index->push_back(i); + TF_RETURN_IF_ERROR( + ForEachHelperMutable(index, shape_tree->elements_[i].get(), func)); + index->pop_back(); + } + + return tensorflow::Status::OK(); +} + +/* static */ +template +tensorflow::Status ShapeTree::ForEachHelper( + ShapeIndex* index, const ShapeTree& shape_tree, + ShapeTree::VisitorFunction func) { + TF_RETURN_IF_ERROR( + func(*index, shape_tree.elements_.empty(), shape_tree.data_)); + for (int i = 0; i < shape_tree.elements_.size(); ++i) { + index->push_back(i); + TF_RETURN_IF_ERROR(ForEachHelper(index, *shape_tree.elements_[i], func)); + index->pop_back(); + } + + return tensorflow::Status::OK(); +} + +template +tensorflow::Status ShapeTree::ForEachElement( + ShapeTree::VisitorFunction func) const { + ShapeIndex index; + return ForEachHelper(&index, *this, func); +} + +template +tensorflow::Status ShapeTree::ForEachMutableElement( + ShapeTree::MutableVisitorFunction func) { + ShapeIndex index; + return ForEachHelperMutable(&index, this, func); +} + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_ diff --git a/tensorflow/compiler/xla/shape_tree_test.cc b/tensorflow/compiler/xla/shape_tree_test.cc new file mode 100644 index 0000000000..d37f536b75 --- /dev/null +++ b/tensorflow/compiler/xla/shape_tree_test.cc @@ -0,0 +1,134 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/shape_tree.h" + +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +class ShapeTreeTest : public ::testing::Test { + protected: + ShapeTreeTest() { + array_shape_ = ShapeUtil::MakeShape(F32, {42, 42, 123}); + tuple_shape_ = + ShapeUtil::MakeTupleShape({array_shape_, array_shape_, array_shape_}); + nested_tuple_shape_ = ShapeUtil::MakeTupleShape( + {array_shape_, ShapeUtil::MakeTupleShape({array_shape_, array_shape_}), + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeTupleShape({array_shape_, array_shape_}), + array_shape_})}); + } + + // An array shape (non-tuple). + Shape array_shape_; + + // A three element tuple shape. + Shape tuple_shape_; + + // A nested tuple shape of the following form: (a, (c, d), ((e, f), g)) + Shape nested_tuple_shape_; +}; + +TEST_F(ShapeTreeTest, ArrayShape) { + ShapeTree shape_tree{array_shape_}; + *shape_tree.mutable_element({}) = 42; + EXPECT_EQ(42, shape_tree.element({})); + *shape_tree.mutable_element({}) = 123; + EXPECT_EQ(123, shape_tree.element({})); + + EXPECT_TRUE(ShapeUtil::Compatible(array_shape_, shape_tree.shape())); + + // Test the copy constructor. + ShapeTree copy{shape_tree}; + EXPECT_EQ(123, copy.element({})); +} + +TEST_F(ShapeTreeTest, TupleShape) { + ShapeTree shape_tree{tuple_shape_}; + *shape_tree.mutable_element({}) = 1; + *shape_tree.mutable_element({0}) = 42; + *shape_tree.mutable_element({1}) = 123; + *shape_tree.mutable_element({2}) = -100; + EXPECT_EQ(1, shape_tree.element({})); + EXPECT_EQ(42, shape_tree.element({0})); + EXPECT_EQ(123, shape_tree.element({1})); + EXPECT_EQ(-100, shape_tree.element({2})); + + EXPECT_TRUE(ShapeUtil::Compatible(tuple_shape_, shape_tree.shape())); + + // Sum all elements in the shape. + int sum = 0; + TF_CHECK_OK(shape_tree.ForEachElement( + [&sum](const ShapeIndex& /*index*/, bool /*is_leaf*/, int data) { + sum += data; + return tensorflow::Status::OK(); + })); + EXPECT_EQ(66, sum); + + // Test the copy constructor. + ShapeTree copy{shape_tree}; + EXPECT_EQ(1, copy.element({})); + EXPECT_EQ(42, copy.element({0})); + EXPECT_EQ(123, copy.element({1})); + EXPECT_EQ(-100, copy.element({2})); + + // Write zero to all data elements. 
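+ // (ForEachMutableElement hands the visitor a pointer to each stored value,
+ // so the assignment below updates the tree in place.)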
+ TF_CHECK_OK(shape_tree.ForEachMutableElement( + [&sum](const ShapeIndex& /*index*/, bool /*is_leaf*/, int* data) { + *data = 0; + return tensorflow::Status::OK(); + })); + EXPECT_EQ(0, shape_tree.element({})); + EXPECT_EQ(0, shape_tree.element({0})); + EXPECT_EQ(0, shape_tree.element({1})); + EXPECT_EQ(0, shape_tree.element({2})); +} + +TEST_F(ShapeTreeTest, NestedTupleShape) { + ShapeTree shape_tree{nested_tuple_shape_}; + *shape_tree.mutable_element({0}) = 42; + *shape_tree.mutable_element({1, 1}) = 123; + *shape_tree.mutable_element({2, 0, 1}) = -100; + EXPECT_EQ(42, shape_tree.element({0})); + EXPECT_EQ(123, shape_tree.element({1, 1})); + EXPECT_EQ(-100, shape_tree.element({2, 0, 1})); + + EXPECT_TRUE(ShapeUtil::Compatible(nested_tuple_shape_, shape_tree.shape())); + + // Test the copy constructor. + ShapeTree copy{shape_tree}; + EXPECT_EQ(42, copy.element({0})); + EXPECT_EQ(123, copy.element({1, 1})); + EXPECT_EQ(-100, copy.element({2, 0, 1})); +} + +TEST_F(ShapeTreeTest, InvalidIndexingTuple) { + ShapeTree shape_tree{tuple_shape_}; + + EXPECT_DEATH(shape_tree.element({4}), ""); +} + +TEST_F(ShapeTreeTest, InvalidIndexingNestedTuple) { + ShapeTree shape_tree{nested_tuple_shape_}; + + EXPECT_DEATH(shape_tree.element({0, 0}), ""); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc new file mode 100644 index 0000000000..a8878e7941 --- /dev/null +++ b/tensorflow/compiler/xla/shape_util.cc @@ -0,0 +1,1024 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/shape_util.h" + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/index_util.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/regexp.h" + +namespace xla { + +/* static */ bool ShapeUtil::CompareShapes(const Shape& lhs, const Shape& rhs, + bool compare_layouts) { + if (IsTuple(lhs)) { + return IsTuple(rhs) && + ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), + [=](const Shape& l, const Shape& r) { + return CompareShapes(l, r, compare_layouts); + }); + } + // Explicitly compare the fields rather than using MessageDifferencer because + // we want empty layouts to be treated identically to missing layouts. 
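+ // For example, two f32[2,3] shapes whose layouts differ only in
+ // minor-to-major order compare the same with compare_layouts=false, but
+ // differ under the layout check immediately below.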
+ if (compare_layouts && + (!ContainersEqual(lhs.layout().minor_to_major(), + rhs.layout().minor_to_major()) || + !ContainersEqual(lhs.layout().padded_dimensions(), + rhs.layout().padded_dimensions()) || + lhs.layout().padding_value() != rhs.layout().padding_value())) { + return false; + } + return SameDimensions(lhs, rhs) && SameElementType(lhs, rhs); +} + +/* static */ bool ShapeUtil::Equal(const Shape& lhs, const Shape& rhs) { + bool equal = CompareShapes(lhs, rhs, /*compare_layouts=*/true); + if (!equal && VLOG_IS_ON(3)) { + // TODO(jeff): Maybe print more info about where lhs and rhs differ + VLOG(3) << "ShapeUtil::Equal differ: lhs = " << lhs.ShortDebugString() + << ", rhs = " << rhs.ShortDebugString(); + } + + return equal; +} + +/* static */ int64 ShapeUtil::TrueRank(const Shape& shape) { + int64 accum = 0; + for (int64 dimension : shape.dimensions()) { + // We do not count zero dimensions. + if (dimension != 1) { + accum += 1; + } + } + return accum; +} + +/* static */ ProgramShape ShapeUtil::MakeProgramShape( + std::initializer_list parameters, Shape result) { + ProgramShape program_shape; + for (const auto& shape : parameters) { + *program_shape.add_parameters() = shape; + } + *program_shape.mutable_result() = result; + return program_shape; +} + +/* static */ Shape ShapeUtil::MakeShape( + PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions) { + DCHECK_NE(TUPLE, element_type); + DCHECK_NE(OPAQUE, element_type); + Shape result; + PopulateShape(element_type, dimensions, &result); + return result; +} + +/* static */ Shape ShapeUtil::MakeShapeWithLayout( + PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice minor_to_major) { + CHECK_EQ(dimensions.size(), minor_to_major.size()); + Shape shape = MakeShape(element_type, dimensions); + auto min2maj = shape.mutable_layout()->mutable_minor_to_major(); + min2maj->Clear(); + for (int64 value : minor_to_major) { + min2maj->Add(value); + } + DCHECK(shape.has_layout()); + TF_DCHECK_OK(ValidateShape(shape)); + return shape; +} + +/* static */ Shape ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( + PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions) { + std::vector layout(dimensions.size()); + std::iota(layout.rbegin(), layout.rend(), static_cast(0)); + return MakeShapeWithLayout(element_type, dimensions, layout); +} + +/* static */ Shape ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout( + const Shape& shape) { + std::vector dims(shape.dimensions_size()); + for (int i = 0; i < shape.dimensions_size(); ++i) { + dims[i] = shape.dimensions(LayoutUtil::Major(shape.layout(), i)); + } + return MakeShapeWithMonotonicDim0MajorLayout(shape.element_type(), dims); +} +/* static */ void ShapeUtil::PopulateShape( + PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, + Shape* shape) { + shape->Clear(); + shape->set_element_type(element_type); + for (int64 dimension : dimensions) { + shape->add_dimensions(dimension); + } + LayoutUtil::SetToDefaultLayout(shape); + TF_DCHECK_OK(ValidateShape(*shape)); +} + +/* static */ Shape ShapeUtil::MakeTupleShape( + tensorflow::gtl::ArraySlice shapes) { + Shape result; + result.set_element_type(TUPLE); + for (const auto& shape : shapes) { + AppendShapeToTuple(shape, &result); + } + TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result)); + return result; +} + +/* static */ Shape ShapeUtil::MakeOpaqueShape() { + Shape result; + result.set_element_type(OPAQUE); + TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result)); + return 
result; +} + +/* static */ void ShapeUtil::AppendShapeToTuple(const Shape& shape, + Shape* tuple_shape) { + TF_DCHECK_OK(ValidateShapeWithOptionalLayout(shape)); + *tuple_shape->add_tuple_shapes() = shape; +} + +/* static */ void ShapeUtil::AppendMajorDimension(int bound, Shape* shape) { + shape->mutable_layout()->add_minor_to_major(ShapeUtil::Rank(*shape)); + shape->add_dimensions(bound); + TF_DCHECK_OK(ValidateShape(*shape)); +} + +/* static */ bool ShapeUtil::ElementIsIntegral(const Shape& shape) { + return primitive_util::IsIntegralType(shape.element_type()); +} + +/* static */ bool ShapeUtil::ElementIsIntegralWithBits(const Shape& shape, + int32 bits) { + return ElementIsIntegral(shape) && ElementHasBitWidth(shape, bits); +} + +/* static */ bool ShapeUtil::ElementHasBitWidth(const Shape& shape, int bits) { + if (shape.element_type() == TUPLE || shape.element_type() == OPAQUE) { + return false; + } + return primitive_util::BitWidth(shape.element_type()) == bits; +} + +/* static */ bool ShapeUtil::ElementIsSigned(const Shape& shape) { + switch (shape.element_type()) { + case S8: + case S16: + case S32: + case S64: + case F16: + case F32: + case F64: + return true; + + case PRED: + case U8: + case U16: + case U32: + case U64: + case TUPLE: + case OPAQUE: + return false; + + default: + LOG(FATAL) << "Unhandled element type " << shape.element_type(); + } +} + +/* static */ bool ShapeUtil::ElementIsFloating(const Shape& shape) { + return primitive_util::IsFloatingPointType(shape.element_type()); +} + +/* static */ bool ShapeUtil::IsTuple(const Shape& shape) { + return shape.element_type() == TUPLE; +} + +/* static */ bool ShapeUtil::IsArray(const Shape& shape) { + return !IsTuple(shape) && !IsOpaque(shape); +} + +/* static */ bool ShapeUtil::IsNestedTuple(const Shape& shape) { + return IsTuple(shape) && std::any_of(shape.tuple_shapes().begin(), + shape.tuple_shapes().end(), IsTuple); +} + +/* static */ bool ShapeUtil::IsEmptyTuple(const Shape& shape) { + return IsTuple(shape) && TupleElementCount(shape) == 0; +} + +/* static */ bool ShapeUtil::IsNil(const Shape& shape) { + return IsEmptyTuple(shape) || HasZeroElements(shape); +} + +/* static */ int64 ShapeUtil::TupleElementCount(const Shape& shape) { + CHECK(IsTuple(shape)); + return shape.tuple_shapes_size(); +} + +/* static */ const Shape& ShapeUtil::GetTupleElementShape(const Shape& shape, + int64 index) { + CHECK(IsTuple(shape)); + CHECK_GT(TupleElementCount(shape), index); + TF_DCHECK_OK(ValidateShapeWithOptionalLayout(shape.tuple_shapes(index))); + return shape.tuple_shapes(index); +} + +/* static */ Shape ShapeUtil::SliceTuple(const Shape& tuple, int64 start, + int64 limit) { + TF_DCHECK_OK(ValidateShapeWithOptionalLayout(tuple)); + CHECK(IsTuple(tuple)); + CHECK_LE(start, TupleElementCount(tuple)); + CHECK_LE(limit, TupleElementCount(tuple)); + + std::vector new_elements(tuple.tuple_shapes().begin() + start, + tuple.tuple_shapes().begin() + limit); + return ShapeUtil::MakeTupleShape(new_elements); +} + +/* static */ bool ShapeUtil::IsOpaque(const Shape& shape) { + return shape.element_type() == OPAQUE; +} + +/* static */ bool ShapeUtil::ShapeIs(const Shape& shape, + PrimitiveType element_type, + std::initializer_list dimensions) { + TF_DCHECK_OK(ValidateShapeWithOptionalLayout(shape)); + if (shape.element_type() != element_type) { + return false; + } + if (shape.dimensions_size() != ShapeUtil::Rank(shape)) { + return false; + } + int64 i = 0; + for (int64 dimension : dimensions) { + if (shape.dimensions(i) != dimension) { + return 
false; + } + i += 1; + } + return true; +} + +/* static */ int64 ShapeUtil::ElementsIn(const Shape& shape) { + CHECK_EQ(shape.dimensions_size(), ShapeUtil::Rank(shape)); + return std::accumulate( + shape.dimensions().begin(), shape.dimensions().end(), 1LL, + std::multiplies()); +} + +/* static */ bool ShapeUtil::HasZeroElements(const Shape& shape) { + return ElementsIn(shape) == 0; +} + +/* static */ bool ShapeUtil::IsScalarF32(const Shape& shape) { + return shape.element_type() == F32 && ShapeUtil::Rank(shape) == 0; +} + +/* static */ string ShapeUtil::HumanString(const Shape& shape) { + if (shape.element_type() == TUPLE) { + string text = "("; + const char* prefix = ""; + for (const Shape& elem_shape : shape.tuple_shapes()) { + tensorflow::strings::StrAppend(&text, prefix, HumanString(elem_shape)); + prefix = ", "; + } + text += ")"; + return text; + } else { + return tensorflow::strings::StrCat( + tensorflow::str_util::Lowercase( + PrimitiveType_Name(shape.element_type())), + "[", tensorflow::str_util::Join(shape.dimensions(), ","), "]"); + } +} + +/* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) { + if (shape.element_type() == TUPLE) { + string text = "("; + const char* prefix = ""; + for (const Shape& elem_shape : shape.tuple_shapes()) { + tensorflow::strings::StrAppend(&text, prefix, + HumanStringWithLayout(elem_shape)); + prefix = ", "; + } + text += ")"; + return text; + } else { + string layout; + if (!IsScalar(shape) && !IsOpaque(shape)) { + if (LayoutUtil::HasLayout(shape)) { + layout = tensorflow::strings::StrCat( + " ", LayoutUtil::HumanString(shape.layout())); + } else { + layout = " (no layout)"; + } + } + return tensorflow::strings::StrCat( + tensorflow::str_util::Lowercase( + PrimitiveType_Name(shape.element_type())), + "[", tensorflow::str_util::Join(shape.dimensions(), ","), "]", layout); + } +} + +/* static */ string ShapeUtil::HumanString(const ProgramShape& program_shape) { + std::vector parameters; + for (auto& shape : program_shape.parameters()) { + const int i = parameters.size(); + parameters.push_back( + tensorflow::strings::StrCat(i < program_shape.parameter_names_size() + ? 
program_shape.parameter_names(i) + : "(unknown)", + ": ", HumanString(shape))); + } + return tensorflow::strings::StrCat( + "(", tensorflow::str_util::Join(parameters, ", "), ") -> ", + HumanString(program_shape.result())); +} + +/* static */ StatusOr ShapeUtil::ParseShapeString(const string& s) { + string element_type_string; + string dimensions_string; + string layout_string; + if (RE2::FullMatch(s, "([fsu]32)\\[([\\d,]*)\\](?: {([\\d,]*)})?", + &element_type_string, &dimensions_string, + &layout_string)) { + auto comma_list_to_int64s = + [&s](const string& input) -> StatusOr> { + std::vector results; + for (const string& piece : tensorflow::str_util::Split(input, ',')) { + int64 element; + if (!tensorflow::strings::safe_strto64(piece.c_str(), &element)) { + return InvalidArgument( + "invalid value in parsed shape string: \"%s\" in \"%s\"", + piece.c_str(), s.c_str()); + } + results.push_back(element); + } + return results; + }; + TF_ASSIGN_OR_RETURN(std::vector dimensions, + comma_list_to_int64s(dimensions_string)); + PrimitiveType primitive_type; + if (element_type_string == "f32") { + primitive_type = F32; + } else if (element_type_string == "s32") { + primitive_type = S32; + } else if (element_type_string == "u32") { + primitive_type = U32; + } else { + LOG(FATAL) << "unhandled element type string: " << element_type_string; + } + Shape result; + if (layout_string.empty()) { + result = ShapeUtil::MakeShape(primitive_type, dimensions); + } else { + TF_ASSIGN_OR_RETURN(std::vector min2maj, + comma_list_to_int64s(layout_string)); + TF_RET_CHECK(dimensions.size() == min2maj.size()); + result = + ShapeUtil::MakeShapeWithLayout(primitive_type, dimensions, min2maj); + } + TF_DCHECK_OK(ValidateShape(result)); + return result; + } + + return InvalidArgument("invalid shape string to parse: \"%s\"", s.c_str()); +} + +/* static */ bool ShapeUtil::SameDimensions(const Shape& lhs, + const Shape& rhs) { + return ContainersEqual(lhs.dimensions(), rhs.dimensions()); +} + +/* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) { + if (lhs.element_type() == TUPLE) { + return rhs.element_type() == TUPLE && + ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), Compatible); + } + return SameDimensions(lhs, rhs) && SameElementType(lhs, rhs); +} + +/* static */ int64 ShapeUtil::GetDimension(const Shape& shape, + int64 dimension_number) { + return shape.dimensions(GetDimensionNumber(shape, dimension_number)); +} + +/* static */ int64 ShapeUtil::GetDimensionNumber(const Shape& shape, + int64 dimension_number) { + if (dimension_number < 0) { + dimension_number += ShapeUtil::Rank(shape); + } + CHECK_GE(dimension_number, 0); + return dimension_number; +} + +/* static */ int64 ShapeUtil::ByteSizeOfPrimitiveType( + PrimitiveType primitive_type) { + switch (primitive_type) { + case PRED: + return sizeof(int8); + case TUPLE: + LOG(FATAL) << "tuples have no definitive size"; + case OPAQUE: + LOG(FATAL) << "opaque have no definitive size"; + case S8: + return sizeof(int8); + case S16: + return sizeof(int16); + case S32: + return sizeof(int32); + case S64: + return sizeof(int64); + case U8: + return sizeof(uint8); + case U16: + return sizeof(uint16); + case U32: + return sizeof(uint32); + case U64: + return sizeof(uint64); + case F16: + return sizeof(float) / 2; + case F32: + return sizeof(float); + case F64: + return sizeof(double); + default: + LOG(FATAL) << "Unhandled primitive type " << primitive_type; + } +} + +/* static */ int64 ShapeUtil::ByteSizeOf(const Shape& shape, + int64 
pointer_size) { + TF_DCHECK_OK(ValidateShape(shape)); + DCHECK_NE(OPAQUE, shape.element_type()); + if (shape.element_type() == TUPLE) { + return pointer_size * shape.tuple_shapes_size(); + } + int64 allocated_element_count; + if (shape.layout().padded_dimensions_size() > 0) { + CHECK_EQ(ShapeUtil::Rank(shape), shape.layout().padded_dimensions_size()); + allocated_element_count = 1; + for (int64 dimension_size : shape.layout().padded_dimensions()) { + allocated_element_count *= dimension_size; + } + } else { + allocated_element_count = ElementsIn(shape); + } + return allocated_element_count * + ByteSizeOfPrimitiveType(shape.element_type()); +} + +/* static */ int64 ShapeUtil::ByteSizeOf(const Shape& shape) { + return ShapeUtil::ByteSizeOf(shape, /*pointer_size=*/sizeof(void*)); +} + +/* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal( + const Shape& shape) { + if (shape.element_type() == TUPLE) { + // Tuple shape. + if (ShapeUtil::Rank(shape) != 0) { + return InvalidArgument("tuples must be rank-0; got rank %lld", + ShapeUtil::Rank(shape)); + } + if (shape.dimensions_size() != 0) { + return InvalidArgument("tuples must not have dimensions specified"); + } + for (auto& element_shape : shape.tuple_shapes()) { + TF_RETURN_IF_ERROR( + ValidateShapeWithOptionalLayoutInternal(element_shape)); + } + return Status::OK(); + } + + // Non-tuple shape. + if (shape.tuple_shapes_size() > 0) { + return InvalidArgument("non-tuple shape has tuple_shapes field"); + } + if (shape.element_type() == PRIMITIVE_TYPE_INVALID) { + return InvalidArgument("shape has invalid element type: %s", + shape.ShortDebugString().c_str()); + } + if (ShapeUtil::Rank(shape) != shape.dimensions_size()) { + return InvalidArgument( + "shape's rank is mismatched with dimension count; rank=%lld " + "dimensions_size=%d", + ShapeUtil::Rank(shape), shape.dimensions_size()); + } + for (int64 i = 0; i < ShapeUtil::Rank(shape); ++i) { + int64 dimension = shape.dimensions(i); + if (dimension < 0) { + return InvalidArgument( + "shape's dimensions must not be < 0; dimension at index %lld was " + "%lld", + i, dimension); + } + } + + return Status::OK(); +} + +/* static */ Status ShapeUtil::ValidateShapeWithOptionalLayout( + const Shape& shape) { + if (LayoutUtil::HasLayout(shape)) { + // Since a layout is present, upgrade to the full set of invariant checks. 
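+ // (For example, a shape built with ShapeUtil::MakeShape() receives a default
+ // layout from PopulateShape() and takes this path; a shape whose layout has
+ // been cleared, e.g. with LayoutUtil::ClearLayout(), is only checked against
+ // the layout-free invariants.)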
+ return ValidateShape(shape); + } + return ValidateShapeWithOptionalLayoutInternal(shape); +} + +/* static */ Status ShapeUtil::ValidateShape(const Shape& shape) { + TF_RETURN_IF_ERROR(ValidateShapeWithOptionalLayoutInternal(shape)); + + return LayoutUtil::ValidateLayoutInShape(shape); +} + +/* static */ Shape ShapeUtil::ChangeElementType(const Shape& shape, + PrimitiveType type) { + Shape new_shape = shape; + new_shape.set_element_type(type); + return new_shape; +} + +/* static */ const Shape& ShapeUtil::GetSubshape(const Shape& shape, + const ShapeIndex& index) { + const Shape* return_shape = &shape; + for (auto i : index) { + CHECK(IsTuple(*return_shape)); + return_shape = &return_shape->tuple_shapes(i); + } + return *return_shape; +} + +/* static */ Shape* ShapeUtil::GetMutableSubshape(Shape* shape, + const ShapeIndex& index) { + Shape* return_shape = shape; + for (auto i : index) { + CHECK(IsTuple(*return_shape)); + return_shape = return_shape->mutable_tuple_shapes(i); + } + return return_shape; +} + +/* static */ Shape ShapeUtil::StripDegenerateDimensions(const Shape& shape) { + std::vector dimension_sizes; + std::vector degenerate_dimensions; + for (int64 i = 0; i < shape.dimensions_size(); ++i) { + if (shape.dimensions(i) == 1) { + degenerate_dimensions.push_back(i); + } else { + dimension_sizes.push_back(shape.dimensions(i)); + } + } + + // Construct minor_to_major of stripped shape. The order of the non-degenerate + // dimensions should be preserved from the original shape. First, create + // vector of the non-degenerate dimensions from the original minor_to_major + // array. + std::vector minor_to_major; + for (int64 i : shape.layout().minor_to_major()) { + if (std::find(degenerate_dimensions.begin(), degenerate_dimensions.end(), + i) == degenerate_dimensions.end()) { + minor_to_major.push_back(i); + } + } + + // The dimensions in minor_to_major need to be renumbered to account for the + // degenerate dimensions which have removed. Decrement each dimension number + // once for each degenerate dimension which has a smaller number. + for (int i = 0; i < minor_to_major.size(); ++i) { + int adjustment = 0; + for (int64 dim : degenerate_dimensions) { + if (minor_to_major[i] > dim) { + adjustment++; + } + } + minor_to_major[i] -= adjustment; + } + + { + std::vector dims(minor_to_major.size()); + std::iota(dims.begin(), dims.end(), 0); + DCHECK(minor_to_major.size() == dims.size() && + std::is_permutation(minor_to_major.begin(), minor_to_major.end(), + dims.begin())); + } + Shape stripped_shape = + shape.has_layout() ? MakeShapeWithLayout(shape.element_type(), + dimension_sizes, minor_to_major) + : MakeShape(shape.element_type(), dimension_sizes); + + VLOG(10) << "Original_shape: " << HumanStringWithLayout(shape); + VLOG(10) << "Stripped_shape: " << HumanStringWithLayout(stripped_shape); + return stripped_shape; +} + +namespace { + +// Helper for ForEachSubshape which visits the subshapes of the given shape in +// DFS pre-order starting with the index. 
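+// As an illustration, for a tuple shape of the form ((a, b), c) the visitor
+// is invoked with these ShapeIndex values, in this order:
+//
+//   {}      the outermost tuple
+//   {0}     the inner tuple (a, b)
+//   {0, 0}  array a
+//   {0, 1}  array b
+//   {1}     array c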
+Status ForEachSubshapeHelper(const Shape& shape, + const ShapeUtil::VisitorFunction func, + ShapeIndex* index) { + TF_RETURN_IF_ERROR(func(shape, *index)); + if (ShapeUtil::IsTuple(shape)) { + for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + index->push_back(i); + TF_RETURN_IF_ERROR(ForEachSubshapeHelper( + ShapeUtil::GetTupleElementShape(shape, i), func, index)); + index->pop_back(); + } + } + return Status::OK(); +} + +// Helper for ForEachMutableSubshape which visits the subshapes of the given +// shape in DFS pre-order starting with the index. +Status ForEachMutableSubshapeHelper( + Shape* shape, const ShapeUtil::MutatingVisitorFunction func, + ShapeIndex* index) { + TF_RETURN_IF_ERROR(func(shape, *index)); + if (ShapeUtil::IsTuple(*shape)) { + for (int64 i = 0; i < ShapeUtil::TupleElementCount(*shape); ++i) { + index->push_back(i); + TF_RETURN_IF_ERROR(ForEachMutableSubshapeHelper( + shape->mutable_tuple_shapes(i), func, index)); + index->pop_back(); + } + } + return Status::OK(); +} + +} // namespace + +/* static */ Status ShapeUtil::ForEachSubshape(const Shape& shape, + VisitorFunction func) { + ShapeIndex index; + return ForEachSubshapeHelper(shape, func, &index); +} + +/* static */ Status ShapeUtil::ForEachMutableSubshape( + Shape* shape, MutatingVisitorFunction func) { + ShapeIndex index; + return ForEachMutableSubshapeHelper(shape, func, &index); +} + +/* static */ Shape ShapeUtil::PermuteDimensions( + tensorflow::gtl::ArraySlice permutation, const Shape& shape) { + Shape new_shape = shape; + new_shape.clear_dimensions(); + for (auto dim : Permute(permutation, shape.dimensions())) { + new_shape.add_dimensions(dim); + } + if (shape.has_layout()) { + new_shape.mutable_layout()->clear_minor_to_major(); + for (auto index : Permute(permutation, shape.layout().minor_to_major())) { + new_shape.mutable_layout()->add_minor_to_major(index); + } + } + return new_shape; +} + +/* static */ std::tuple, std::vector> +ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre, + const Shape& shape_post) { + auto nil = std::make_tuple(false, std::vector(), std::vector()); + + std::vector deleted_indices; + std::vector inserted_indices; + // Returns false if any input/output index between prior_unmodified_dim_pair + // and unmodified_dim_pair have size >1. Otherwise, returns true and appends + // the degerenate input/output dimensions in the gap to + // deleted_indices/inserted_indices respectively. + auto check_modified_dims = [&shape_pre, &shape_post, &deleted_indices, + &inserted_indices]( + std::pair prior_unmodified_dim_pair, + std::pair unmodified_dim_pair) { + for (int64 modified_input_dim = prior_unmodified_dim_pair.first + 1; + modified_input_dim < unmodified_dim_pair.first; ++modified_input_dim) { + if (shape_pre.dimensions(modified_input_dim) > 1) { + return false; + } + deleted_indices.push_back(modified_input_dim); + } + for (int64 modified_output_dim = prior_unmodified_dim_pair.second + 1; + modified_output_dim < unmodified_dim_pair.second; + ++modified_output_dim) { + if (shape_post.dimensions(modified_output_dim) > 1) { + return false; + } + inserted_indices.push_back(modified_output_dim); + } + return true; + }; + + std::vector> unmodified_dims = + DimensionsUnmodifiedByReshape(shape_pre, shape_post); + // Returns nil if the reshape modifies any non-degenerate input/output + // dimension. DimensionsUnmodifiedByReshape gives us all unmodified + // dimensions, so we only need to check whether dimensions in the gaps (thus + // modified) have size >1. 
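+ // For example (shapes illustrative): reshaping f32[2,1,3] to f32[2,3,1]
+ // only moves size-1 dimensions, so the result is (true, {1}, {2}): input
+ // dimension 1 was deleted and output dimension 2 was inserted. Reshaping
+ // f32[2,3] to f32[3,2] modifies dimensions of size >1, so the result is
+ // nil, i.e. (false, {}, {}).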
+ for (size_t i = 0; i <= unmodified_dims.size(); ++i) { + // Check (modified) dimensions between unmodified_dims[i-1] and + // unmodified_dims[i]. + auto prior_unmodified_dim_pair = + i > 0 ? unmodified_dims[i - 1] : std::make_pair(-1LL, -1LL); + auto unmodified_dim_pair = + i < unmodified_dims.size() + ? unmodified_dims[i] + : std::make_pair(ShapeUtil::Rank(shape_pre), + ShapeUtil::Rank(shape_post)); + if (!check_modified_dims(prior_unmodified_dim_pair, unmodified_dim_pair)) { + return nil; + } + } + + return std::make_tuple(true, deleted_indices, inserted_indices); +} + +/* static */ std::vector> +ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, + const Shape& output_shape) { + // Returns nil if the input/output shape has zero elements. This is safe but + // might be too conservative. Not a big deal for now because IR emitted for + // zero-element shapes are often trivially optimizable without the help of + // this method. + if (ShapeUtil::ElementsIn(input_shape) == 0 || + ShapeUtil::ElementsIn(output_shape) == 0) { + return std::vector>(); + } + + std::vector> unmodified_dims; + int64 input_dim = 0; + int64 output_dim = 0; + + // A reshape preserves input_dim as output_dim iff + // 1. input_dim and output_dim have the same size. + // 2. The size of the input subarray from dimension 0 to input_dim-1 equals + // that of the output subarray from dimension 0 to output_dim-1. + VLOG(3) << "DimensionsUnmodifiedByReshape: input_shape=" + << ShapeUtil::HumanString(input_shape) + << ", output_shape=" << ShapeUtil::HumanString(output_shape); + while (input_dim < ShapeUtil::Rank(input_shape) && + output_dim < ShapeUtil::Rank(output_shape)) { + // partial_input_size is the product of sizes of input dimensions + // inclusively between the input_dim when this loop iteration starts and the + // current input_dim. partial_output_size is that of output dimensions. We + // compute these two values incrementally to save time. + int64 partial_input_size = input_shape.dimensions(input_dim); + int64 partial_output_size = output_shape.dimensions(output_dim); + // Move input_dim and output_dim forward until + // partial_input_size==partial_output_size. + while (partial_input_size != partial_output_size) { + if (partial_input_size < partial_output_size) { + ++input_dim; + partial_input_size *= input_shape.dimensions(input_dim); + } else { + ++output_dim; + partial_output_size *= output_shape.dimensions(output_dim); + } + } + CHECK_LT(input_dim, ShapeUtil::Rank(input_shape)); + CHECK_LT(output_dim, ShapeUtil::Rank(output_shape)); + if (input_shape.dimensions(input_dim) == + output_shape.dimensions(output_dim)) { + unmodified_dims.push_back({input_dim, output_dim}); + VLOG(3) << "Matching dimension pair: " << input_dim << ' ' << output_dim; + } + ++input_dim; + ++output_dim; + } + + return unmodified_dims; +} + +/* static */ bool ShapeUtil::TransposeIsBitcast( + const Shape& input_shape, const Shape& output_shape, + tensorflow::gtl::ArraySlice dimension_mapping) { + // Can't insert bitcasts without layout information. + if (!LayoutUtil::HasLayout(input_shape) && + !LayoutUtil::HasLayout(output_shape)) { + return false; + } + + // Padding is not handled. + if (LayoutUtil::IsPadded(input_shape) && LayoutUtil::IsPadded(output_shape)) { + return false; + } + + // Check the reshape permutes the positions of each dimension in the + // minor-to-major order. positions[i]=k means dimension `i` is k-th minor. 
+ // input_positions = apply(dimension_mapping, output_positions) + // + // Because the positions of each dimension are the inverse permutation of the + // minor-to-major order, the above check is equivalent to + // inverse(input_dimensions) = + // apply(dimension_mapping, inverse(output_dimensions)) + // # `I` indicates identity permutation. + // apply(input_dimensions, I) = + // apply(dimension_mapping, apply(output_dimensions, I)) + // apply(input_dimensions, I) = + // apply((dimension_mapping * output_dimensions), I) + // input_dimensions = dimension_mapping * output_dimensions + return ContainersEqual( + ComposePermutations(dimension_mapping, + AsInt64Slice(output_shape.layout().minor_to_major())), + input_shape.layout().minor_to_major()); +} + +/* static */ bool ShapeUtil::ReshapeIsBitcast(const Shape& input_shape, + const Shape& output_shape) { + // Can't convert reshapes into bitcasts without layout information. + if (!LayoutUtil::HasLayout(input_shape) || + !LayoutUtil::HasLayout(output_shape)) { + return false; + } + + // Padding is not handled. + if (LayoutUtil::IsPadded(input_shape) || LayoutUtil::IsPadded(output_shape)) { + return false; + } + + CHECK_EQ(ShapeUtil::ElementsIn(input_shape), + ShapeUtil::ElementsIn(output_shape)); + if (ShapeUtil::ElementsIn(input_shape) == 0) { + return true; + } + + // TL;DR: The rest of the method checks that the reshape does not change the + // physical location of any unit input or output index. Unit indices have + // exactly one dimension that equals 1 and other dimensions 0. This condition + // is necessary for the reshape to be a bitcast, because a bitcast-equivalent + // reshape shouldn't change the physical location of any element. It is also a + // sufficient condition as is proved below (note: many details are omitted for + // space). + // + // Definitions: + // + // * Denote the input shape by IS and output shape by OS. IS[i] or OS[i] means + // the size of i-th least significant dimension of IS or OS (this is opposite + // to how we define the index of Shape::dimensions()). + // + // * Given an input or output index I, denote by p(I) I's physical linear + // index (or physical index for short) and l(I) I's logical linear index (or + // logical index for short). + // + // * Given a logical index k, denote by II(k) the input index whose linear + // index is k, and OI(k) the corresponding output index. + // + // * Denote by IT[i] the increment of physical index if i-th dimension of the + // input index is increased by 1. Similarly, OT[i] means the increment if i-th + // dimension of the output index is increased by 1. Note that IT[i] or OT[i] + // is a function of IS or OS and the layout, and not dependent on the specific + // input or output index. + // + // To prove the reshape from IS to OS is a bitcast, it is sufficient to prove + // that, for any linear index k, p(II(k))=p(OI(k)). We prove this by + // induction. We know p(II(0))=p(OI(0)) is trivially true, so what's left is + // to prove, with every increment on k, the above formula still holds. + // + // First, suppose reshaping from IS to OS is non-factorizable (we discuss + // refactorizable reshapes later). A reshape from IS to OS is factorizable, if + // there exists (i,j) such that + // + // 0<=i<=|IS| + // 0<=j<=|OS| + // |IS|-i+|OS|-j > 0 (i.e., i,j mustn't both point to the end) + // product(IS[i], IS[i+1], ..., IS[|IS|-1]) + // = product(OS[j], OS[j+1], ..., OS[|OS|-1]) + // + // p(II(k))=p(OI(k)) is trivially true for k=0 because p(II(0)) and p(OI(0)) + // are both 0. 
+  // It's also trivially true for k=1, because II(1) and OI(1) are
+  // unit indices which are already tested. This also means IT[0]=OT[0]
+  // because p(II(1))=IT[0] and p(OI(1))=OT[0].
+  //
+  // Furthermore, p(II(k))=p(OI(k)) for k<min(IS[0],OS[0]): going from logical
+  // index k-1 to k only increases dimension 0 of both the input and the output
+  // index by 1, so both physical indices grow by IT[0]=OT[0].
+  //
+  // Now consider the first roll-over point, k=min(IS[0],OS[0]); WLOG assume
+  // IS[0]<OS[0], so k=IS[0].
+  // Note that IS[0]!=OS[0] because the reshape is non-factorizable. From
+  // logical index k-1 to logical index k, dimension 1 of the input index
+  // is increased by 1 and dimension 0 is reset to 0 thus decreased by
+  // IS[0]-1. Therefore, the physical input index is increased by
+  //
+  //   p(II(k)) - p(II(k-1)) = IT[1] - (IS[0]-1) * IT[0]
+  //
+  // Because IS[0]<OS[0], the output index only increases dimension 0 by 1, so
+  //
+  //   p(OI(k)) - p(OI(k-1)) = OT[0] = IT[0]
+  //
+  // The input unit index with dimension 1 set to 1 has logical linear index
+  // IS[0] and physical index IT[1]; the output index with the same logical
+  // linear index has dimension 0 equal to IS[0] and physical index
+  // IS[0]*OT[0]. The unit-index condition therefore gives
+  // IT[1] = IS[0]*OT[0] = IS[0]*IT[0], so both physical indices grow by IT[0]
+  // and p(II(k))=p(OI(k)) still holds. Repeating this argument at every later
+  // roll-over point (details omitted) extends the equality to all k, which
+  // proves the non-factorizable case.
+  //
+  // Finally, suppose the reshape from IS to OS is factorizable. For example,
+  //
+  //   [7x9x2x15] -> [63x6x5]
+  //
+  // can be factorized into
+  //
+  //   [7x9] -> [63] and [2x15] -> [6x5].
+  //
+  // Suppose input index I=(x3,x2,x1,x0) and output index O=(y2,y1,y0) have the
+  // same logical linear index. According to the factorization, we know
+  // l(x3,x2,0,0)=l(y2,0,0) and l(0,0,x1,x0)=l(0,y1,y0). Using the proof for
+  // non-factorizable reshapes, we can prove p(0,0,x1,x0)=p(0,y1,y0). Using a
+  // similar proof, with the increment of the logical index set to
+  // IS[1]*IS[0]=OS[1]*OS[0]=30 instead of 1, we can prove
+  // p(x3,x2,0,0)=p(y2,0,0) too. Therefore,
+  //
+  //   p(x3,x2,x1,x0) = p(x3,x2,0,0) + p(0,0,x1,x0)
+  //                  = p(y2,0,0) + p(0,0,y1,y0)
+  //                  = p(y2,y1,y0)
+  //
+  // check_input_unit_indices checks one way of the condition: each input unit
+  // index is mapped to an output index with the same physical location. This
+  // lambda will be called again with input_shape and output_shape reversed to
+  // check the other way.
+  auto check_input_unit_indices = [](const Shape& input_shape,
+                                     const Shape& output_shape) {
+    // input_shape_dim0_major/output_shape_dim0_major has the same "dimensions"
+    // as input_shape/output_shape and the dimension-0-major layout. These two
+    // shapes are used for conversion between logical linear indices and
+    // multi-dimensional indices.
+    Shape input_shape_dim0_major =
+        ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
+            input_shape.element_type(), AsInt64Slice(input_shape.dimensions()));
+    Shape output_shape_dim0_major =
+        ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
+            output_shape.element_type(),
+            AsInt64Slice(output_shape.dimensions()));
+
+    for (int64 input_dim = 0; input_dim < ShapeUtil::Rank(input_shape);
+         ++input_dim) {
+      if (input_shape.dimensions(input_dim) <= 1) {
+        continue;
+      }
+
+      std::vector<int64> input_unit_index(ShapeUtil::Rank(input_shape), 0);
+      input_unit_index[input_dim] = 1;
+      int64 logical_linear_index =
+          IndexUtil::MultidimensionalIndexToLinearIndex(input_shape_dim0_major,
+                                                        input_unit_index);
+      // output_index has the same logical linear index as input_unit_index.
+      std::vector<int64> output_index =
+          IndexUtil::LinearIndexToMultidimensionalIndex(output_shape_dim0_major,
+                                                        logical_linear_index);
+      // Check input_unit_index and output_index have the same physical linear
+      // index.
+      if (IndexUtil::MultidimensionalIndexToLinearIndex(input_shape,
+                                                        input_unit_index) !=
+          IndexUtil::MultidimensionalIndexToLinearIndex(output_shape,
+                                                        output_index)) {
+        return false;
+      }
+    }
+    return true;
+  };
+  return check_input_unit_indices(input_shape, output_shape) &&
+         check_input_unit_indices(output_shape, input_shape);
+}
+
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
new file mode 100644
index 0000000000..35fd714b0b
--- /dev/null
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -0,0 +1,393 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Shapes are protobuf messages, so this utility header offers a bunch of
+// functionality for querying / poking at them.
+
+#ifndef TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_
+#define TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_
+
+#include <initializer_list>
+#include <vector>
+
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+// An index for specifying a particular nested subshape within a shape. Used in
+// ShapeUtil::GetSubshape and other interfaces. Shapes are recursive data
+// structures (trees) and ShapeIndex defines a path through the tree where each
+// element of ShapeIndex indexes into a tuple (or nested tuple) within the
+// shape. For a non-nested tuple, an index has a single element. For example,
+// given a 3-element tuple (a, b, c) containing arrays a, b, and c, the index
+// {1} corresponds to array b. For a nested tuple, the index can have more than
+// one element. For the nested tuple (a, (b, c, d), e) below are the values
+// corresponding to the given indices:
+//
+//   index {0}    : array a
+//   index {1, 2} : array d
+//   index {2}    : array e
+//   index {0, 0} : invalid index (element at {0} is an array not a tuple)
+//
+// For indexing into array shapes, the index is always trivially empty, i.e. {}.
+//
+// ShapeIndex is a trivial wrapper around std::vector<int64> with a minimum
+// number of methods implemented.
+class ShapeIndex {
+ public:
+  ShapeIndex() = default;
+  ShapeIndex(std::initializer_list<int64> init) : indices_(init) {}
+
+  bool empty() const { return indices_.empty(); }
+  size_t size() const { return indices_.size(); }
+  void push_back(int64 value) { indices_.push_back(value); }
+  void pop_back() { indices_.pop_back(); }
+
+  std::vector<int64>::const_iterator begin() const { return indices_.begin(); }
+  std::vector<int64>::const_iterator end() const { return indices_.end(); }
+  std::vector<int64>::iterator begin() { return indices_.begin(); }
+  std::vector<int64>::iterator end() { return indices_.end(); }
+
+  const int64& operator[](size_t i) const { return indices_[i]; }
+  int64& operator[](size_t i) { return indices_[i]; }
+
+  bool operator==(const ShapeIndex& other) const {
+    return indices_ == other.indices_;
+  }
+  bool operator!=(const ShapeIndex& other) const { return !(*this == other); }
+
+ private:
+  std::vector<int64> indices_;
+};
+
+// Namespaced collection of (static) shape utilities.
+//
+// These are all effectively convenience functions for testing/tweaking proto
+// properties, which do invariant checks before / after the operation.
+class ShapeUtil {
+ public:
+  // Returns the number of elements that are contained within the provided
+  // shape; e.g. for rank 0 (scalars) the result is always 1.
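+  // For example, ElementsIn(f32[3, 5]) is 15, ElementsIn(f32[3, 0, 5]) is 0,
+  // and ElementsIn(f32[]) is 1.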
+  static int64 ElementsIn(const Shape& shape);
+
+  // Returns true if 'shape' has zero elements.
+  static bool HasZeroElements(const Shape& shape);
+
+  // Returns the number of bytes required for an allocation of shape. The
+  // |pointer_size| parameter is used for calculating the size of tuple
+  // shapes. This includes only the size of the top-level buffer. For example,
+  // a tuple is stored as an array of pointers to other buffers. In this case,
+  // this method only returns the size of the pointer array.
+  static int64 ByteSizeOf(const Shape& shape, int64 pointer_size);
+
+  // Returns the number of bytes required for an allocation of shape.
+  // The calculation for tuple shapes assumes that we are utilizing host
+  // pointers.
+  // Precondition: !ShapeUtil::IsOpaque(shape)
+  static int64 ByteSizeOf(const Shape& shape);
+
+  // Returns the number of bytes used to store the primitive_type.
+  //
+  // Precondition: !ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape)
+  static int64 ByteSizeOfPrimitiveType(PrimitiveType primitive_type);
+
+  // Returns a human-readable string that represents the given shape, with or
+  // without layout. e.g. "f32[42x12] {0, 1}" or "f32[64]".
+  static string HumanString(const Shape& shape);
+  static string HumanStringWithLayout(const Shape& shape);
+
+  // As above, but for program shapes, returns a string of the form:
+  //
+  //   (param_name: f32[42x12], ...) -> f32[24x42]
+  static string HumanString(const ProgramShape& shape);
+
+  // Parses a ShapeUtil::HumanString-format shape string back into a shape
+  // object.
+  static StatusOr<Shape> ParseShapeString(const string& s);
+
+  // Returns whether the LHS and RHS shapes have the same dimensions; note:
+  // does not check element type.
+  static bool SameDimensions(const Shape& lhs, const Shape& rhs);
+
+  // Returns whether the lhs and rhs shapes have the same element type.
+  static bool SameElementType(const Shape& lhs, const Shape& rhs) {
+    return lhs.element_type() == rhs.element_type();
+  }
+
+  // Returns true if the rank, dimension sizes, and element type are
+  // identical. Layout is ignored. Tuple elements are compared recursively for
+  // compatibility.
+  static bool Compatible(const Shape& lhs, const Shape& rhs);
+
+  // Returns whether the lhs and rhs shapes are identical protobufs.
+  static bool Equal(const Shape& lhs, const Shape& rhs);
+
+  // Returns the rank (number of dimensions) of the given shape.
+  static int64 Rank(const Shape& shape) { return shape.dimensions_size(); }
+
+  // Returns the number of dimensions for which the dimension is not
+  // (trivially) 1. e.g., f32[2x1x1] has a true rank of 1D, the other
+  // dimensions are just fluff. Note that zero dimensions are included in the
+  // true rank, e.g., f32[3,0,1] has a true rank of 2D.
+  static int64 TrueRank(const Shape& shape);
+
+  static ProgramShape MakeProgramShape(std::initializer_list<Shape> parameters,
+                                       Shape result);
+
+  ////////////////////
+  // Scalar-specific
+
+  static bool IsScalar(const Shape& shape) {
+    return !IsTuple(shape) && !IsOpaque(shape) && Rank(shape) == 0;
+  }
+  static bool IsEffectiveScalar(const Shape& shape) {
+    return !IsTuple(shape) && !IsOpaque(shape) && TrueRank(shape) == 0;
+  }
+  static bool IsScalarF32(const Shape& shape);
+
+  // Extracts the size of the shape's dimension at dimension number
+  // GetDimensionNumber(dimension_number).
+  static int64 GetDimension(const Shape& shape, int64 dimension_number);
+
+  // Resolves a dimension number, supporting negative indexing.
+  //
+  // Negative indexing has similar semantics to Python. For an N-dimensional
+  // array, dimension -1 is equivalent to dimension N-1, -2 is equivalent to
+  // N-2, and so on.
+  //
+  // This function always returns a non-negative dimension number for any given
+  // dimension_number (which itself can be negative).
+  static int64 GetDimensionNumber(const Shape& shape, int64 dimension_number);
+
+  // Returns a shape with the same dimensions as the original, but with the
+  // element type changed to type.
+  static Shape ChangeElementType(const Shape& original, PrimitiveType type);
+
+  // Creates a tuple shape from a slice of element shapes within the tuple.
+  static Shape MakeTupleShape(tensorflow::gtl::ArraySlice<Shape> shapes);
+
+  // Creates an opaque shape. These are generally used for threading a context
+  // into a custom operation.
+  static Shape MakeOpaqueShape();
+
+  // Appends a shape to the given tuple.
+  static void AppendShapeToTuple(const Shape& shape, Shape* tuple_shape);
+
+  // Appends a major dimension to the shape with the given bound.
+  static void AppendMajorDimension(int bound, Shape* shape);
+
+  // Returns an empty tuple shape. Can be used to indicate side-effects.
+  static Shape MakeNil() { return MakeTupleShape({}); }
+
+  // Constructs a new shape with the given element type and sequence of
+  // dimensions.
+  static Shape MakeShape(PrimitiveType element_type,
+                         tensorflow::gtl::ArraySlice<int64> dimensions);
+
+  // Constructs a new shape with the given minor_to_major order in its Layout.
+  // Returns a value shape such that shape.has_layout().
+  static Shape MakeShapeWithLayout(
+      PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
+      tensorflow::gtl::ArraySlice<int64> minor_to_major);
+
+  // Constructs a new shape with major-first layout.
+  static Shape MakeShapeWithMonotonicDim0MajorLayout(
+      PrimitiveType element_type,
+      tensorflow::gtl::ArraySlice<int64> dimensions);
+
+  // Returns a new shape with major-first layout that has the same layout of
+  // elements with a different shape.
+  static Shape NormalizeShapeToMonotonicDim0MajorLayout(const Shape& shape);
+
+  // As MakeShape, but the object to write to is passed in.
+  static void PopulateShape(PrimitiveType element_type,
+                            tensorflow::gtl::ArraySlice<int64> dimensions,
+                            Shape* shape);
+
+  // Validates that the provided shape satisfies invariants.
+  static Status ValidateShape(const Shape& shape);
+
+  // Validates that the provided shape satisfies invariants, except those that
+  // pertain to layout.
+  //
+  // Layout is optional for client-provided shapes, so that the compiler may
+  // determine and assign an optimized layout.
+  static Status ValidateShapeWithOptionalLayout(const Shape& shape);
+
+  // Returns whether the element type of the shape is integral (signed or
+  // unsigned). Note that predicates are not considered integral here, since
+  // they are logical values.
+  static bool ElementIsIntegral(const Shape& shape);
+
+  // Returns whether the element type of the shape is floating point.
+  static bool ElementIsFloating(const Shape& shape);
+
+  // Returns whether the element type has the given bit width.
+  static bool ElementHasBitWidth(const Shape& shape, int bits);
+
+  // Returns whether the element type of the shape is integral and has
+  // the specified number of bits.
+  static bool ElementIsIntegralWithBits(const Shape& shape, int bits);
+
+  // Returns whether the element type of the shape is signed. Note
+  // that floating point numbers are signed.
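+  // For example, s32[2] and f32[2] shapes are signed, while u32[2] and pred[2]
+  // shapes are not.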
+  static bool ElementIsSigned(const Shape& shape);
+
+  // Returns whether the shape is a tuple.
+  static bool IsTuple(const Shape& shape);
+
+  // Returns whether the shape is an array.
+  static bool IsArray(const Shape& shape);
+
+  // Returns whether the shape is an opaque type.
+  static bool IsOpaque(const Shape& shape);
+
+  // Returns whether the shape is a tuple with at least one element which is
+  // also a tuple.
+  static bool IsNestedTuple(const Shape& shape);
+
+  // Returns true if shape is an empty tuple.
+  static bool IsEmptyTuple(const Shape& shape);
+
+  // Returns true if shape is an empty tuple, or is an array with no elements.
+  static bool IsNil(const Shape& shape);
+
+  // Returns the number of elements in the given tuple shape.
+  // Precondition: IsTuple(shape)
+  static int64 TupleElementCount(const Shape& shape);
+
+  // Returns the tuple element shape at given index.
+  // Precondition: IsTuple(shape) && TupleElementCount(shape) > index
+  static const Shape& GetTupleElementShape(const Shape& shape, int64 index);
+
+  // Slices tuple elements in the range [start, limit) and returns a new tuple
+  // shape. E.g. a tuple like (f32, s32, u32) would slice via 1,3 to (s32, u32).
+  static Shape SliceTuple(const Shape& tuple, int64 start, int64 limit);
+
+  // Shorthand for testing whether a shape is of a given element type and
+  // sequence of dimensions.
+  static bool ShapeIs(const Shape& shape, PrimitiveType element_type,
+                      std::initializer_list<int64> dimensions);
+
+  // GetSubshape and GetMutableSubshape return a particular nested Shape within
+  // the given Shape argument.
+  static const Shape& GetSubshape(const Shape& shape, const ShapeIndex& index);
+  static Shape* GetMutableSubshape(Shape* shape, const ShapeIndex& index);
+
+  // Calls the given visitor function for each subshape of the given shape.
+  // Returns early if an error status is returned. Subshapes are visited in DFS
+  // pre-order starting with the entire shape (index {}).
+  using VisitorFunction = std::function<Status(const Shape& /*subshape*/,
+                                               const ShapeIndex& /*index*/)>;
+  static Status ForEachSubshape(const Shape& shape, VisitorFunction func);
+
+  // Mutating variant of ForEachSubshape.
+  using MutatingVisitorFunction =
+      std::function<Status(Shape* /*subshape*/, const ShapeIndex& /*index*/)>;
+  static Status ForEachMutableSubshape(Shape* shape,
+                                       MutatingVisitorFunction func);
+
+  // Removes all degenerate dimensions (size one) from the given shape. The
+  // stripped minor_to_major preserves the relative ordering of non-degenerate
+  // dimensions. The stripped shape has the property that the underlying
+  // representation (bits in memory) for the stripped shape is the same as the
+  // original shape modulo padding.
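+  // (A degenerate dimension's index is always 0 and it contributes a factor of
+  // 1 to the strides of the remaining dimensions, so dropping it cannot change
+  // any element's physical offset.)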
+  // Examples:
+  //
+  // input shape:    F32 [1, 2, 1], minor_to_major = {0, 1, 2}
+  // stripped shape: F32 [2], minor_to_major = {0}
+  //
+  // input shape:    F32 [6, 1, 5], minor_to_major = {2, 0, 1}
+  // stripped shape: F32 [6, 5], minor_to_major = {1, 0}
+  //
+  // input shape:    F32 [1, 7, 1, 6, 5, 1], minor_to_major = {0, 2, 5, 4, 3, 1}
+  // stripped shape: F32 [7, 6, 5], minor_to_major = {0, 2, 1}
+  //
+  // input shape:    F32 [1, 1], minor_to_major = {0, 1}
+  // stripped shape: F32 [], minor_to_major = {}
+  // Precondition: !ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape)
+  static Shape StripDegenerateDimensions(const Shape& shape);
+
+  // Permutes the dimensions by the given permutation, so
+  // return_value.dimensions[permutation[i]] = argument.dimensions[i]
+  static Shape PermuteDimensions(tensorflow::gtl::ArraySlice<int64> permutation,
+                                 const Shape& shape);
+
+  // If we can go from `shape_pre` to `shape_post` by merely inserting or
+  // deleting 1-sized dimensions, return the indices in `shape_pre` of the
+  // deleted dimensions and the indices in `shape_post` of the inserted
+  // dimensions.
+  // For example, if `shape_pre = {a_1, a_2, ..., a_m}` and
+  // `shape_post = {b_1, b_2, ..., b_n}` where we can find some sequence of `i`s
+  // and some sequence of `j`s so `a_i = 1` for each `i` and `b_j = 1` for each
+  // `j` and `a_(k-s) = b_(k-t)` where `s` and `t` are the number of `i`s and
+  // `j`s less than `k` for all other `k`, we return the `i`s and `j`s.
+  // For another example, if `shape_pre = shape_post = {}`, we return `{}`.
+  static std::tuple<bool, std::vector<int64>, std::vector<int64>>
+  InsertedOrDeleted1SizedDimensions(const Shape& shape_pre,
+                                    const Shape& shape_post);
+
+  // Suppose a reshape transforms input_shape to output_shape. Returns a vector
+  // of pairs that indicate the input and output dimensions that this reshape
+  // doesn't logically (i.e. ignoring the layout) modify. For each pair (I,O) in
+  // the returned vector, the reshape transforms any input index whose I-th
+  // dimension is x to an output index whose O-th dimension is x too.
+  //
+  // Post-condition: the returned vector is sorted (by both input and output
+  // dimensions because input and output dimensions have the same order).
+  //
+  // Example:
+  //   input  shape = T[a, b, x, y, cd]
+  //   output shape = T[ab, x, 1, y, c, d]
+  //   return value = {{2, 1}, {3, 3}}
+  //
+  // The two pairs represent the input and output dimension of size x and
+  // those of size y.
+  static std::vector<std::pair<int64, int64>> DimensionsUnmodifiedByReshape(
+      const Shape& input_shape, const Shape& output_shape);
+
+  // Returns whether a transpose from input_shape to output_shape with
+  // dimension mapping "dimension_mapping" produces a result which is bit-wise
+  // identical to its input and thus may be replaced with a bitcast.
+  static bool TransposeIsBitcast(
+      const Shape& input_shape, const Shape& output_shape,
+      tensorflow::gtl::ArraySlice<int64> dimension_mapping);
+
+  // Returns whether a reshape from "input_shape" to "output_shape" is a
+  // bitcast.
+  static bool ReshapeIsBitcast(const Shape& input_shape,
+                               const Shape& output_shape);
+
+ private:
+  // Recursive helper for comparing the equality of two shapes. Returns true if
+  // the shapes are the same. If compare_layouts is true, then layouts must
+  // also match.
+  static bool CompareShapes(const Shape& lhs, const Shape& rhs,
+                            bool compare_layouts);
+
+  // Validates all of the non-layout properties of the shape -- this is a
+  // helper used by both the layout-optional and layout-required public
+  // methods.
+ static Status ValidateShapeWithOptionalLayoutInternal(const Shape& shape); + + TF_DISALLOW_COPY_AND_ASSIGN(ShapeUtil); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_ diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc new file mode 100644 index 0000000000..4e8a496e7e --- /dev/null +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -0,0 +1,506 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/shape_util.h" + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +TEST(ShapeUtilTest, GetDimensionHelperCanNegativeIndex) { + Shape matrix = ShapeUtil::MakeShape(F32, {2, 3}); + EXPECT_EQ(3, ShapeUtil::GetDimension(matrix, -1)); + EXPECT_EQ(2, ShapeUtil::GetDimension(matrix, -2)); +} + +TEST(ShapeUtilTest, GetDimensionHelperExampleInDocumentationTest) { + auto shape = ShapeUtil::MakeShape(F32, {1, 2, 3, 4}); + ASSERT_EQ(4, ShapeUtil::GetDimension(shape, -1)); +} + +TEST(ShapeUtilTest, NegativeIndexOobFails) { + Shape matrix = ShapeUtil::MakeShape(F32, {2, 3}); + ASSERT_DEATH(ShapeUtil::GetDimension(matrix, -3), "dimension_number >= 0"); +} + +TEST(ShapeUtilTest, Rank1DimensionIndexing) { + Shape shape = ShapeUtil::MakeShape(F32, {3}); + ASSERT_EQ(3, shape.dimensions(0)); +} + +TEST(ShapeUtilTest, Rank2DimensionIndexing) { + Shape shape = ShapeUtil::MakeShape(F32, {3, 2}); + ASSERT_EQ(2, shape.dimensions(1)); + ASSERT_EQ(3, shape.dimensions(0)); +} + +TEST(ShapeUtilTest, Rank3DimensionIndexing) { + Shape shape = ShapeUtil::MakeShape(F32, {3, 2, 7}); + ASSERT_EQ(7, shape.dimensions(2)); + ASSERT_EQ(2, shape.dimensions(1)); + ASSERT_EQ(3, shape.dimensions(0)); +} + +TEST(ShapeUtilTest, Rank4DimensionIndexing) { + Shape shape = ShapeUtil::MakeShape(F32, {3, 2, 7, 8}); + ASSERT_EQ(8, shape.dimensions(3)); + ASSERT_EQ(7, shape.dimensions(2)); + ASSERT_EQ(2, shape.dimensions(1)); + ASSERT_EQ(3, shape.dimensions(0)); +} + +TEST(ShapeUtilTest, ParseShapeStringR2F32) { + string shape_string = "f32[123,456]"; + Shape actual = ShapeUtil::ParseShapeString(shape_string).ValueOrDie(); + Shape expected = ShapeUtil::MakeShape(F32, {123, 456}); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST(ShapeUtilTest, CompatibleIdenticalShapes) { + Shape shape1 = ShapeUtil::MakeShape(F32, {3, 2}); + Shape shape2 = ShapeUtil::MakeShape(F32, {3, 2}); + ASSERT_TRUE(ShapeUtil::Compatible(shape1, shape2)); +} + +TEST(ShapeUtilTest, CompatibleNotIdenticalShapes) { + Shape shape_1 = ShapeUtil::MakeShape(F32, {3, 2}); + auto layout_1 = shape_1.mutable_layout(); + 
layout_1->clear_minor_to_major(); + layout_1->add_minor_to_major(0); + layout_1->add_minor_to_major(1); + + Shape shape_2 = ShapeUtil::MakeShape(F32, {3, 2}); + auto layout_2 = shape_2.mutable_layout(); + layout_2->clear_minor_to_major(); + layout_2->add_minor_to_major(1); + layout_2->add_minor_to_major(0); + + EXPECT_FALSE(ShapeUtil::Equal(shape_1, shape_2)); + EXPECT_TRUE(ShapeUtil::Compatible(shape_1, shape_2)); +} + +TEST(ShapeUtilTest, IncompatibleDifferentElementShapes) { + Shape shape_1 = ShapeUtil::MakeShape(F32, {3, 2}); + Shape shape_2 = ShapeUtil::MakeShape(PRED, {3, 2}); + EXPECT_FALSE(ShapeUtil::Compatible(shape_1, shape_2)); +} + +TEST(ShapeUtilTest, CompatibleTuples) { + Shape tuple1 = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(PRED, {4, 5})}); + Shape tuple2 = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(PRED, {4, 5})}); + EXPECT_TRUE(ShapeUtil::Compatible(tuple1, tuple2)); +} + +TEST(ShapeUtilTest, IncompatibleTuplesWithSwappedElements) { + Shape tuple1 = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(PRED, {4, 5}), ShapeUtil::MakeShape(F32, {3, 2})}); + Shape tuple2 = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(PRED, {4, 5})}); + EXPECT_FALSE(ShapeUtil::Compatible(tuple1, tuple2)); +} + +TEST(ShapeUtilTest, IncompatibleTuplesWithDifferentPrimitiveType) { + Shape tuple1 = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(PRED, {4, 5}), ShapeUtil::MakeShape(F32, {3, 2})}); + Shape tuple2 = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(PRED, {4, 5}), ShapeUtil::MakeShape(S32, {3, 2})}); + EXPECT_FALSE(ShapeUtil::Compatible(tuple1, tuple2)); +} + +TEST(ShapeUtilTest, IncompatibleTuplesWithDifferentDimensions) { + Shape tuple1 = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(PRED, {4, 5}), ShapeUtil::MakeShape(F32, {3, 2})}); + Shape tuple2 = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(PRED, {4, 5}), ShapeUtil::MakeShape(F32, {4, 2})}); + EXPECT_FALSE(ShapeUtil::Compatible(tuple1, tuple2)); +} + +TEST(ShapeUtilTest, EmptyLayoutEqualsMissingLayout) { + // A shape with a missing layout should be equal to a shape with an empty + // layout. + Shape scalar1 = ShapeUtil::MakeShape(F32, {}); + Shape scalar2 = ShapeUtil::MakeShape(F32, {}); + + EXPECT_TRUE(ShapeUtil::Equal(scalar1, scalar2)); + + scalar1.clear_layout(); // Remove layout field. + scalar2.mutable_layout(); // Create empty layout field. 
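+  // A scalar has no dimensions to order, so an empty minor_to_major list and a
+  // missing layout field describe the same physical representation.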
+ + EXPECT_TRUE(ShapeUtil::Equal(scalar1, scalar2)); +} + +TEST(ShapeUtilTest, ScalarUnpopulatedLayoutEqualsScalarLayout) { + Shape scalar_unpopulated = ShapeUtil::MakeShape(F32, {}); + scalar_unpopulated.clear_layout(); + ASSERT_FALSE(scalar_unpopulated.has_layout()) + << ShapeUtil::HumanStringWithLayout(scalar_unpopulated); + + const Shape scalar_populated = ShapeUtil::MakeShapeWithLayout(F32, {}, {}); + ASSERT_TRUE(scalar_populated.has_layout()) + << ShapeUtil::HumanStringWithLayout(scalar_populated); + + EXPECT_TRUE(ShapeUtil::Equal(scalar_unpopulated, scalar_populated)); +} + +TEST(ShapeUtilTest, ByteSizeOfWithoutPadding) { + EXPECT_EQ(4, ShapeUtil::ByteSizeOfPrimitiveType(F32)); + EXPECT_EQ(4, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(F32, {}))); + EXPECT_EQ(800, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(F32, {10, 20}))); + + EXPECT_EQ(8, ShapeUtil::ByteSizeOfPrimitiveType(F64)); + EXPECT_EQ(8, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(F64, {}))); + EXPECT_EQ(1600, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(F64, {10, 20}))); +} + +TEST(ShapeUtilTest, ByteSizeOfWithPadding) { + EXPECT_EQ(4, ShapeUtil::ByteSizeOfPrimitiveType(F32)); + Shape shape = ShapeUtil::MakeShape(F32, {10, 20}); + EXPECT_EQ(800, ShapeUtil::ByteSizeOf(shape)); + + shape.mutable_layout()->add_padded_dimensions(15); + shape.mutable_layout()->add_padded_dimensions(21); + EXPECT_EQ(15 * 21 * 4, ShapeUtil::ByteSizeOf(shape)); +} + +TEST(ShapeUtilTest, NestedTuple) { + EXPECT_FALSE(ShapeUtil::IsNestedTuple(ShapeUtil::MakeTupleShape({}))); + EXPECT_FALSE(ShapeUtil::IsNestedTuple( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {})}))); + EXPECT_TRUE(ShapeUtil::IsNestedTuple( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeTupleShape({})}))); + EXPECT_FALSE(ShapeUtil::IsNestedTuple(ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(S32, {})}))); + EXPECT_TRUE(ShapeUtil::IsNestedTuple(ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeTupleShape({})}))); + EXPECT_TRUE(ShapeUtil::IsNestedTuple(ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeTupleShape({}), ShapeUtil::MakeShape(S32, {})}))); + EXPECT_TRUE(ShapeUtil::IsNestedTuple(ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeTupleShape({}), ShapeUtil::MakeTupleShape({})}))); +} + +TEST(ShapeUtilTest, ElementsIn) { + EXPECT_EQ(1, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {}))); + EXPECT_EQ(0, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {0}))); + EXPECT_EQ(1, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {1}))); + EXPECT_EQ(1, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {1, 1}))); + EXPECT_EQ(2, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {2}))); + EXPECT_EQ(2, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {2, 1}))); + EXPECT_EQ(15, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {3, 5}))); + EXPECT_EQ(0, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {3, 0, 5}))); + EXPECT_EQ(0, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {0, 3, 0}))); + EXPECT_EQ(15, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {1, 3, 5}))); + EXPECT_EQ(221, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {13, 17}))); +} + +TEST(ShapeUtilTest, HasZeroElements) { + EXPECT_EQ(false, ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {}))); + EXPECT_EQ(true, ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {0}))); + EXPECT_EQ(false, ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {1}))); + EXPECT_EQ(false, + ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {1, 1}))); + EXPECT_EQ(false, 
ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {2}))); + EXPECT_EQ(false, + ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {2, 1}))); + EXPECT_EQ(false, + ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {3, 5}))); + EXPECT_EQ(true, + ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {3, 0, 5}))); + EXPECT_EQ(true, + ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {0, 3, 0}))); + EXPECT_EQ(false, + ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {1, 3, 5}))); + EXPECT_EQ(false, + ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {13, 17}))); +} + +TEST(ShapeUtilTest, SameDimensions) { + EXPECT_TRUE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {}), + ShapeUtil::MakeShape(S32, {}))); + EXPECT_TRUE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {}), + ShapeUtil::MakeShape(F32, {}))); + EXPECT_TRUE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {1}), + ShapeUtil::MakeShape(S32, {1}))); + EXPECT_TRUE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {0}), + ShapeUtil::MakeShape(S32, {0}))); + EXPECT_TRUE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {2}), + ShapeUtil::MakeShape(S32, {2}))); + EXPECT_FALSE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {1}), + ShapeUtil::MakeShape(F32, {2}))); + EXPECT_FALSE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {0, 0}), + ShapeUtil::MakeShape(F32, {0}))); + EXPECT_FALSE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {1}), + ShapeUtil::MakeShape(F32, {1, 1}))); + EXPECT_FALSE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {}), + ShapeUtil::MakeShape(F32, {1}))); + EXPECT_FALSE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {1}), + ShapeUtil::MakeShape(F32, {1, 1}))); + EXPECT_FALSE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {1}), + ShapeUtil::MakeShape(F32, {1, 0}))); + EXPECT_FALSE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {1, 1}), + ShapeUtil::MakeShape(F32, {1, 2}))); +} + +TEST(ShapeUtilTest, GetSubshape) { + // Test array shape. + Shape array_shape = ShapeUtil::MakeShape(F32, {42, 42, 123}); + EXPECT_TRUE( + ShapeUtil::Equal(array_shape, ShapeUtil::GetSubshape(array_shape, {}))); + EXPECT_TRUE(ShapeUtil::Equal( + array_shape, *ShapeUtil::GetMutableSubshape(&array_shape, {}))); + + // Test tuple shape. + Shape tuple_shape = + ShapeUtil::MakeTupleShape({array_shape, array_shape, array_shape}); + EXPECT_TRUE( + ShapeUtil::Equal(tuple_shape, ShapeUtil::GetSubshape(tuple_shape, {}))); + EXPECT_TRUE( + ShapeUtil::Equal(array_shape, ShapeUtil::GetSubshape(tuple_shape, {0}))); + EXPECT_TRUE( + ShapeUtil::Equal(array_shape, ShapeUtil::GetSubshape(tuple_shape, {1}))); + EXPECT_TRUE( + ShapeUtil::Equal(array_shape, ShapeUtil::GetSubshape(tuple_shape, {2}))); + + // Test nested tuple shape. 
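+  // The shape built below is (array, (array, array), ((array, array), array));
+  // e.g. index {1} names the first inner pair and index {2, 0} names the pair
+  // nested inside the third element.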
+ Shape nested_tuple_shape = ShapeUtil::MakeTupleShape( + {array_shape, ShapeUtil::MakeTupleShape({array_shape, array_shape}), + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeTupleShape({array_shape, array_shape}), + array_shape})}); + EXPECT_TRUE(ShapeUtil::Equal(nested_tuple_shape, + ShapeUtil::GetSubshape(nested_tuple_shape, {}))); + EXPECT_TRUE(ShapeUtil::Equal( + array_shape, ShapeUtil::GetSubshape(nested_tuple_shape, {0}))); + EXPECT_TRUE( + ShapeUtil::Equal(ShapeUtil::MakeTupleShape({array_shape, array_shape}), + ShapeUtil::GetSubshape(nested_tuple_shape, {1}))); + EXPECT_TRUE( + ShapeUtil::Equal(ShapeUtil::MakeTupleShape({array_shape, array_shape}), + ShapeUtil::GetSubshape(nested_tuple_shape, {2, 0}))); +} + +TEST(ShapeUtilTest, HumanString) { + Shape opaque = ShapeUtil::MakeOpaqueShape(); + Shape scalar = ShapeUtil::MakeShape(F32, {}); + Shape matrix = ShapeUtil::MakeShape(U32, {1, 2}); + Shape matrix2 = ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1}); + Shape tuple = ShapeUtil::MakeTupleShape({opaque, scalar, matrix, matrix2}); + Shape nested_tuple = ShapeUtil::MakeTupleShape({tuple, matrix}); + + EXPECT_EQ("opaque[]", ShapeUtil::HumanString(opaque)); + EXPECT_EQ("f32[]", ShapeUtil::HumanString(scalar)); + EXPECT_EQ("u32[1,2]", ShapeUtil::HumanString(matrix)); + EXPECT_EQ("s32[3,4]", ShapeUtil::HumanString(matrix2)); + EXPECT_EQ("(opaque[], f32[], u32[1,2], s32[3,4])", + ShapeUtil::HumanString(tuple)); + EXPECT_EQ("((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])", + ShapeUtil::HumanString(nested_tuple)); + + EXPECT_EQ("opaque[]", ShapeUtil::HumanStringWithLayout(opaque)); + EXPECT_EQ("f32[]", ShapeUtil::HumanStringWithLayout(scalar)); + EXPECT_EQ("u32[1,2] {1,0}", ShapeUtil::HumanStringWithLayout(matrix)); + EXPECT_EQ("s32[3,4] {0,1}", ShapeUtil::HumanStringWithLayout(matrix2)); + EXPECT_EQ("(opaque[], f32[], u32[1,2] {1,0}, s32[3,4] {0,1})", + ShapeUtil::HumanStringWithLayout(tuple)); + EXPECT_EQ( + "((opaque[], f32[], u32[1,2] {1,0}, s32[3,4] {0,1}), u32[1,2] {1,0})", + ShapeUtil::HumanStringWithLayout(nested_tuple)); + + ProgramShape prog = ShapeUtil::MakeProgramShape( + {opaque, scalar, matrix, matrix2, tuple, nested_tuple}, nested_tuple); + EXPECT_EQ( + "((unknown): opaque[], " + "(unknown): f32[], " + "(unknown): u32[1,2], " + "(unknown): s32[3,4], " + "(unknown): (opaque[], f32[], u32[1,2], s32[3,4]), " + "(unknown): ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])) -> " + "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])", + ShapeUtil::HumanString(prog)); + + prog.add_parameter_names("arg0"); + prog.add_parameter_names("scalar"); + prog.add_parameter_names("matrix"); + prog.add_parameter_names("matrix2"); + prog.add_parameter_names("tuple"); + prog.add_parameter_names("nested_tuple"); + EXPECT_EQ( + "(arg0: opaque[], " + "scalar: f32[], " + "matrix: u32[1,2], " + "matrix2: s32[3,4], " + "tuple: (opaque[], f32[], u32[1,2], s32[3,4]), " + "nested_tuple: ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])) -> " + "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])", + ShapeUtil::HumanString(prog)); +} + +TEST(ShapeUtilTest, ForEachSubshapeArray) { + const Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); + int calls = 0; + EXPECT_IS_OK(ShapeUtil::ForEachSubshape( + shape, [&calls, &shape](const Shape& subshape, const ShapeIndex& index) { + EXPECT_EQ(&shape, &subshape); + EXPECT_TRUE(index.empty()); + ++calls; + return tensorflow::Status::OK(); + })); + EXPECT_EQ(1, calls); +} + +TEST(ShapeUtilTest, ForEachSubshapeNestedTuple) { + const Shape shape = 
ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {42}), + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {101}), + ShapeUtil::MakeShape(PRED, {33})})}); + int calls = 0; + EXPECT_IS_OK(ShapeUtil::ForEachSubshape( + shape, [&calls, &shape](const Shape& subshape, const ShapeIndex& index) { + EXPECT_TRUE( + ShapeUtil::Equal(subshape, ShapeUtil::GetSubshape(shape, index))); + if (calls == 0) { + // Visitation should go from outside in. + EXPECT_TRUE(index.empty()); + } else if (calls == 4) { + // Last visitation should be to the array with 33 elements. + EXPECT_EQ(33, ShapeUtil::ElementsIn(subshape)); + } + ++calls; + return tensorflow::Status::OK(); + })); + EXPECT_EQ(5, calls); +} + +TEST(ShapeUtilTest, ForEachMutableSubshapeNestedTuple) { + Shape shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {42}), + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {101}), + ShapeUtil::MakeShape(PRED, {33})})}); + int calls = 0; + EXPECT_IS_OK(ShapeUtil::ForEachMutableSubshape( + &shape, [&calls, &shape](const Shape* subshape, const ShapeIndex& index) { + // Pointer values should be equal + EXPECT_EQ(subshape, ShapeUtil::GetMutableSubshape(&shape, index)); + if (calls == 0) { + // Visitation should go from outside in. + EXPECT_TRUE(index.empty()); + } else if (calls == 4) { + // Last visitation should be to the array with 33 elements. + EXPECT_EQ(33, ShapeUtil::ElementsIn(*subshape)); + } + ++calls; + return tensorflow::Status::OK(); + })); + EXPECT_EQ(5, calls); +} + +TEST(ShapeUtilTest, InsertedOrDeleted1SizedDimensions) { + Shape shape0 = ShapeUtil::MakeShape(S32, {9, 1, 4}); + Shape shape1 = ShapeUtil::MakeShape(S32, {1, 9, 4, 1}); + Shape shape2 = ShapeUtil::MakeShape(S32, {3, 1, 12}); + EXPECT_TRUE(std::get<0>( + ShapeUtil::InsertedOrDeleted1SizedDimensions(shape0, shape1))); + EXPECT_FALSE(std::get<0>( + ShapeUtil::InsertedOrDeleted1SizedDimensions(shape0, shape2))); +} + +TEST(ShapeUtilTest, DimensionsUnmodifiedByReshape_1x1x1x1_to_1x1x1) { + // All output dimensions should be unmodified. One of the input dimensions is + // modified because the input rank is larger by one. + EXPECT_EQ(3, + ShapeUtil::DimensionsUnmodifiedByReshape( + ShapeUtil::MakeShape(S32, {1, 1, 1, 1}), + ShapeUtil::MakeShape(S32, {1, 1, 1})) + .size()); +} + +TEST(ShapeUtilTest, DimensionsUnmodifiedByReshape_1x1x1_to_1x1x1x1) { + // All input dimensions should be unmodified. One of the output dimensions is + // modified because the output rank is larger by one. + EXPECT_EQ(3, + ShapeUtil::DimensionsUnmodifiedByReshape( + ShapeUtil::MakeShape(S32, {1, 1, 1}), + ShapeUtil::MakeShape(S32, {1, 1, 1, 1})) + .size()); +} + +TEST(ShapeUtilTest, DimensionsUnmodifiedByReshape_4x1x3x5x6x7_to_2x6x1x5x1x42) { + // The only matching dimension is the one with size 5. + // 4, 1, 3, 5, 6, 7 + // | + // 2, 6, 1, 5, 1, 42 + EXPECT_TRUE( + ContainersEqual(ShapeUtil::DimensionsUnmodifiedByReshape( + ShapeUtil::MakeShape(S32, {4, 1, 3, 5, 6, 7}), + ShapeUtil::MakeShape(S32, {2, 6, 1, 5, 1, 42})), + std::vector>({{3, 3}}))); +} + +TEST(ShapeUtilTest, ReshapeIsBitcast_3x4_6x2) { + for (bool input_is_row_major : {true, false}) { + for (bool output_is_row_major : {true, false}) { + Layout input_layout = input_is_row_major ? LayoutUtil::MakeLayout({1, 0}) + : LayoutUtil::MakeLayout({0, 1}); + Layout output_layout = output_is_row_major + ? LayoutUtil::MakeLayout({1, 0}) + : LayoutUtil::MakeLayout({0, 1}); + // Suppose the input is logically (i.e. 
ignoring its layout) + // 0 1 2 3 + // 4 5 6 7 + // 8 9 10 11 + // + // The reshape transforms the input to logically + // 0 1 + // 2 3 + // 4 5 + // 6 7 + // 8 9 + // 10 11 + // + // The input and the output have the same underlying data only if they + // are both row-major. + EXPECT_EQ( + ShapeUtil::ReshapeIsBitcast( + ShapeUtil::MakeShapeWithLayout( + F32, {3, 4}, AsInt64Slice(input_layout.minor_to_major())), + ShapeUtil::MakeShapeWithLayout( + F32, {6, 2}, AsInt64Slice(output_layout.minor_to_major()))), + input_is_row_major && output_is_row_major); + } + } +} + +TEST(ShapeUtilTest, ReshapeIsBitcast_3x2x2_6x2_Dim1IsMostMinor) { + EXPECT_TRUE(ShapeUtil::ReshapeIsBitcast( + ShapeUtil::MakeShapeWithLayout(F32, {3, 2, 2}, {1, 0, 2}), + ShapeUtil::MakeShapeWithLayout(F32, {6, 2}, {0, 1}))); +} + +TEST(AlgebraicSimplifierTest, ReshapeIsBitcast_3x2x2_6x2_Dim0IsMostMinor) { + EXPECT_FALSE(ShapeUtil::ReshapeIsBitcast( + ShapeUtil::MakeShapeWithLayout(F32, {3, 2, 2}, {0, 1, 2}), + ShapeUtil::MakeShapeWithLayout(F32, {6, 2}, {0, 1}))); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/status.h b/tensorflow/compiler/xla/status.h new file mode 100644 index 0000000000..f3b561fada --- /dev/null +++ b/tensorflow/compiler/xla/status.h @@ -0,0 +1,46 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_STATUS_H_ +#define TENSORFLOW_COMPILER_XLA_STATUS_H_ + +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/macros.h" + +namespace xla { + +#if defined(__clang__) +// Only clang supports warn_unused_result as a type annotation. +class TF_MUST_USE_RESULT Status; +#endif + +// Simple wrapper around tensorflow::Status that has the MUST_USE_RESULT +// annotation above. When tensorflow::Status adopts this annotation, this can +// simply become a "using tensorflow::Status". +class Status : public tensorflow::Status { + public: + static Status OK() { return tensorflow::Status::OK(); } + + // Note: implicit constructor. + Status(tensorflow::Status other) : tensorflow::Status(other) {} + + Status() : tensorflow::Status() {} + Status(tensorflow::error::Code code, tensorflow::StringPiece msg) + : tensorflow::Status(code, msg) {} +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_STATUS_H_ diff --git a/tensorflow/compiler/xla/status_macros.cc b/tensorflow/compiler/xla/status_macros.cc new file mode 100644 index 0000000000..a6b1f9004f --- /dev/null +++ b/tensorflow/compiler/xla/status_macros.cc @@ -0,0 +1,170 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/status_macros.h" + +#include + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stacktrace.h" + +namespace xla { +namespace status_macros { + +static Status MakeStatus(tensorflow::error::Code code, const string& message) { + return Status(code, message); +} + +// Log the error at the given severity, optionally with a stack trace. +// If log_severity is NUM_SEVERITIES, nothing is logged. +static void LogError(const Status& status, const char* filename, int line, + int log_severity, bool should_log_stack_trace) { + if (TF_PREDICT_TRUE(log_severity != tensorflow::NUM_SEVERITIES)) { + string stack_trace; + if (should_log_stack_trace) { + stack_trace = + tensorflow::strings::StrCat("\n", tensorflow::CurrentStackTrace()); + } + switch (log_severity) { + case tensorflow::INFO: + LOG(INFO) << status << stack_trace; + break; + case tensorflow::WARNING: + LOG(WARNING) << status << stack_trace; + break; + case tensorflow::ERROR: + LOG(ERROR) << status << stack_trace; + break; + case tensorflow::FATAL: + LOG(FATAL) << status << stack_trace; + break; + case tensorflow::NUM_SEVERITIES: + break; + default: + LOG(FATAL) << "Unknown LOG severity " << log_severity; + } + } +} + +// Make a Status with a code, error message and payload, +// and also send it to LOG() using the given filename +// and line (unless should_log is false, or log_severity is +// NUM_SEVERITIES). If should_log_stack_trace is true, the stack +// trace is included in the log message (ignored if should_log is +// false). +static Status MakeError(const char* filename, int line, + tensorflow::error::Code code, const string& message, + bool should_log, int log_severity, + bool should_log_stack_trace) { + if (TF_PREDICT_FALSE(code == tensorflow::error::OK)) { + LOG(ERROR) << "Cannot create error with status OK"; + code = tensorflow::error::UNKNOWN; + } + const Status status = MakeStatus(code, message); + if (TF_PREDICT_TRUE(should_log)) { + LogError(status, filename, line, log_severity, should_log_stack_trace); + } + return status; +} + +// This method is written out-of-line rather than in the header to avoid +// generating a lot of inline code for error cases in all callers. 
+void MakeErrorStream::CheckNotDone() const { impl_->CheckNotDone(); } + +MakeErrorStream::Impl::Impl(const char* file, int line, + tensorflow::error::Code code, + MakeErrorStream* error_stream, + bool is_logged_by_default) + : file_(file), + line_(line), + code_(code), + is_done_(false), + should_log_(is_logged_by_default), + log_severity_(tensorflow::ERROR), + should_log_stack_trace_(false), + make_error_stream_with_output_wrapper_(error_stream) {} + +MakeErrorStream::Impl::Impl(const Status& status, + PriorMessageHandling prior_message_handling, + const char* file, int line, + MakeErrorStream* error_stream) + : file_(file), + line_(line), + // Make sure we show some error, even if the call is incorrect. + code_(!status.ok() ? status.code() : tensorflow::error::UNKNOWN), + prior_message_handling_(prior_message_handling), + prior_message_(status.error_message()), + is_done_(false), + // Error code type is not visible here, so we can't call + // IsLoggedByDefault. + should_log_(true), + log_severity_(tensorflow::ERROR), + should_log_stack_trace_(false), + make_error_stream_with_output_wrapper_(error_stream) { + DCHECK(!status.ok()) << "Attempted to append/prepend error text to status OK"; +} + +MakeErrorStream::Impl::~Impl() { + // Note: error messages refer to the public MakeErrorStream class. + + if (!is_done_) { + LOG(ERROR) << "MakeErrorStream destructed without getting Status: " << file_ + << ":" << line_ << " " << stream_.str(); + } +} + +Status MakeErrorStream::Impl::GetStatus() { + // Note: error messages refer to the public MakeErrorStream class. + + // Getting a Status object out more than once is not harmful, but + // it doesn't match the expected pattern, where the stream is constructed + // as a temporary, loaded with a message, and then casted to Status. + if (is_done_) { + LOG(ERROR) << "MakeErrorStream got Status more than once: " << file_ << ":" + << line_ << " " << stream_.str(); + } + + is_done_ = true; + + const string& stream_str = stream_.str(); + const string str = + prior_message_handling_ == kAppendToPriorMessage + ? tensorflow::strings::StrCat(prior_message_, stream_str) + : tensorflow::strings::StrCat(stream_str, prior_message_); + if (TF_PREDICT_FALSE(str.empty())) { + return MakeError(file_, line_, code_, + tensorflow::strings::StrCat( + str, "Error without message at ", file_, ":", line_), + true /* should_log */, + tensorflow::ERROR /* log_severity */, + should_log_stack_trace_); + } else { + return MakeError(file_, line_, code_, str, should_log_, log_severity_, + should_log_stack_trace_); + } +} + +void MakeErrorStream::Impl::CheckNotDone() const { + if (is_done_) { + LOG(ERROR) << "MakeErrorStream shift called after getting Status: " << file_ + << ":" << line_ << " " << stream_.str(); + } +} + +} // namespace status_macros +} // namespace xla diff --git a/tensorflow/compiler/xla/status_macros.h b/tensorflow/compiler/xla/status_macros.h new file mode 100644 index 0000000000..aa12cda666 --- /dev/null +++ b/tensorflow/compiler/xla/status_macros.h @@ -0,0 +1,220 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_STATUS_MACROS_H_
+#define TENSORFLOW_COMPILER_XLA_STATUS_MACROS_H_
+
+#include <memory>
+#include <ostream>  // NOLINT
+#include <sstream>
+#include <string>
+
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace xla {
+namespace status_macros {
+
+// Stream object used to collect error messages in MAKE_ERROR macros
+// or append error messages with APPEND_ERROR. It accepts any
+// arguments with operator<< to build an error string, and then has an
+// implicit cast operator to Status, which converts the
+// logged string to a Status object and returns it, after logging the
+// error. At least one call to operator<< is required; a compile time
+// error will be generated if none are given. Errors will only be
+// logged by default for certain status codes, as defined in
+// IsLoggedByDefault. This class will give ERROR errors if you don't
+// retrieve a Status exactly once before destruction.
+//
+// The class converts into an intermediate wrapper object
+// MakeErrorStreamWithOutput to check that the error stream gets at least one
+// item of input.
+class MakeErrorStream {
+ public:
+  // Wrapper around MakeErrorStream that only allows for output. This
+  // is created as output of the first operator<< call on
+  // MakeErrorStream. The bare MakeErrorStream does not have a
+  // Status operator. The net effect of that is that you
+  // have to call operator<< at least once or else you'll get a
+  // compile time error.
+  class MakeErrorStreamWithOutput {
+   public:
+    explicit MakeErrorStreamWithOutput(MakeErrorStream* error_stream)
+        : wrapped_error_stream_(error_stream) {}
+
+    template <typename T>
+    MakeErrorStreamWithOutput& operator<<(const T& value) {
+      *wrapped_error_stream_ << value;
+      return *this;
+    }
+
+    // Implicit cast operators to Status and StatusOr<T>.
+    // Exactly one of these must be called exactly once before destruction.
+    operator Status() { return wrapped_error_stream_->GetStatus(); }
+    template <typename T>
+    operator xla::StatusOr<T>() {
+      return wrapped_error_stream_->GetStatus();
+    }
+
+   private:
+    MakeErrorStream* wrapped_error_stream_;
+
+    TF_DISALLOW_COPY_AND_ASSIGN(MakeErrorStreamWithOutput);
+  };
+
+  // When starting from an existing error status, this determines whether we'll
+  // append or prepend to that status's error message.
+  enum PriorMessageHandling { kAppendToPriorMessage, kPrependToPriorMessage };
+
+  // Make an error with the given code.
+  template <typename ERROR_CODE_TYPE>
+  MakeErrorStream(const char* file, int line, ERROR_CODE_TYPE code)
+      : impl_(new Impl(file, line, code, this, true)) {}
+
+  template <typename T>
+  MakeErrorStreamWithOutput& operator<<(const T& value) {
+    CheckNotDone();
+    impl_->stream_ << value;
+    return impl_->make_error_stream_with_output_wrapper_;
+  }
+
+  // When this message is logged (see with_logging()), include the stack trace.
+  MakeErrorStream& with_log_stack_trace() {
+    impl_->should_log_stack_trace_ = true;
+    return *this;
+  }
+
+  // Adds RET_CHECK failure text to error message.
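+  // For example, TF_RET_CHECK(2 > 3) produces an INTERNAL status whose message
+  // begins with "RET_CHECK failure (<file>:<line>) 2 > 3".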
+ MakeErrorStreamWithOutput& add_ret_check_failure(const char* condition) { + return *this << "RET_CHECK failure (" << impl_->file_ << ":" << impl_->line_ + << ") " << condition << " "; + } + + private: + class Impl { + public: + Impl(const char* file, int line, tensorflow::error::Code code, + MakeErrorStream* error_stream, bool is_logged_by_default = true); + Impl(const Status& status, PriorMessageHandling prior_message_handling, + const char* file, int line, MakeErrorStream* error_stream); + + ~Impl(); + + // This must be called exactly once before destruction. + Status GetStatus(); + + void CheckNotDone() const; + + private: + const char* file_; + int line_; + tensorflow::error::Code code_; + + PriorMessageHandling prior_message_handling_ = kAppendToPriorMessage; + string prior_message_; + bool is_done_; // true after Status object has been returned + std::ostringstream stream_; + bool should_log_; + int log_severity_; + bool should_log_stack_trace_; + + // Wrapper around the MakeErrorStream object that has a + // Status conversion. The first << operator called on + // MakeErrorStream will return this object, and only this object + // can implicitly convert to Status. The net effect of + // this is that you'll get a compile time error if you call + // MAKE_ERROR etc. without adding any output. + MakeErrorStreamWithOutput make_error_stream_with_output_wrapper_; + + friend class MakeErrorStream; + TF_DISALLOW_COPY_AND_ASSIGN(Impl); + }; + + void CheckNotDone() const; + + // Returns the status. Used by MakeErrorStreamWithOutput. + Status GetStatus() const { return impl_->GetStatus(); } + + // Store the actual data on the heap to reduce stack frame sizes. + std::unique_ptr impl_; + + TF_DISALLOW_COPY_AND_ASSIGN(MakeErrorStream); +}; + +// Provides a conversion to bool so that it can be used inside an if statement +// that declares a variable. +class StatusAdaptorForMacros { + public: + explicit StatusAdaptorForMacros(Status status) : status_(std::move(status)) {} + + StatusAdaptorForMacros(const StatusAdaptorForMacros&) = delete; + StatusAdaptorForMacros& operator=(const StatusAdaptorForMacros&) = delete; + + explicit operator bool() const { return TF_PREDICT_TRUE(status_.ok()); } + + Status&& Consume() { return std::move(status_); } + + private: + Status status_; +}; + +} // namespace status_macros +} // namespace xla + +#define TF_RET_CHECK(condition) \ + while (TF_PREDICT_FALSE(!(condition))) \ + return xla::status_macros::MakeErrorStream(__FILE__, __LINE__, \ + tensorflow::error::INTERNAL) \ + .with_log_stack_trace() \ + .add_ret_check_failure(#condition) + +#define TF_ASSIGN_OR_ASSERT_OK(lhs, rexpr) \ + TF_ASSIGN_OR_ASSERT_OK_IMPL( \ + TF_STATUS_MACROS_CONCAT_NAME(_status_or_value, __COUNTER__), lhs, \ + rexpr); + +#define TF_ASSIGN_OR_ASSERT_OK_IMPL(statusor, lhs, rexpr) \ + auto statusor = (rexpr); \ + ASSERT_TRUE(statusor.status().ok()) << statusor.status(); \ + lhs = statusor.ConsumeValueOrDie() + +#define TF_STATUS_MACROS_CONCAT_NAME(x, y) TF_STATUS_MACROS_CONCAT_IMPL(x, y) +#define TF_STATUS_MACROS_CONCAT_IMPL(x, y) x##y + +#define TF_ASSIGN_OR_RETURN(...) \ + TF_STATUS_MACRO_GET_VARIADIC_IMPL(__VA_ARGS__, TF_ASSIGN_OR_RETURN_IMPL_3, \ + TF_ASSIGN_OR_RETURN_IMPL_2) \ + (__VA_ARGS__) + +#define TF_STATUS_MACRO_GET_VARIADIC_IMPL(_1, _2, _3, NAME, ...) 
NAME
+
+#define TF_ASSIGN_OR_RETURN_IMPL_2(lhs, rexpr) \
+  TF_ASSIGN_OR_RETURN_IMPL_3(lhs, rexpr)
+
+#define TF_ASSIGN_OR_RETURN_IMPL_3(lhs, rexpr) \
+  TF_ASSIGN_OR_RETURN_IMPL(                    \
+      TF_STATUS_MACROS_CONCAT_NAME(_status_or_value, __COUNTER__), lhs, rexpr)
+
+#define TF_ASSIGN_OR_RETURN_IMPL(statusor, lhs, rexpr) \
+  auto statusor = (rexpr);                             \
+  if (TF_PREDICT_FALSE(!statusor.ok())) {              \
+    return statusor.status();                          \
+  }                                                    \
+  lhs = std::move(statusor.ValueOrDie())
+
+#endif  // TENSORFLOW_COMPILER_XLA_STATUS_MACROS_H_
diff --git a/tensorflow/compiler/xla/status_macros_test.cc b/tensorflow/compiler/xla/status_macros_test.cc
new file mode 100644
index 0000000000..4e7b9161db
--- /dev/null
+++ b/tensorflow/compiler/xla/status_macros_test.cc
@@ -0,0 +1,112 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/status_macros.h"
+
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+
+Status RetCheckFail() {
+  TF_RET_CHECK(2 > 3);
+  return Status::OK();
+}
+
+Status RetCheckFailWithExtraMessage() {
+  TF_RET_CHECK(2 > 3) << "extra message";
+  return Status::OK();
+}
+
+Status RetCheckSuccess() {
+  TF_RET_CHECK(3 > 2);
+  return Status::OK();
+}
+
+TEST(StatusMacros, RetCheckFailing) {
+  Status status = RetCheckFail();
+  EXPECT_EQ(status.code(), tensorflow::error::INTERNAL);
+  EXPECT_MATCH(status.error_message(),
+               xla::testing::ContainsRegex("RET_CHECK failure.*2 > 3"));
+}
+
+TEST(StatusMacros, RetCheckFailingWithExtraMessage) {
+  Status status = RetCheckFailWithExtraMessage();
+  EXPECT_EQ(status.code(), tensorflow::error::INTERNAL);
+  EXPECT_MATCH(status.error_message(),
+               xla::testing::ContainsRegex("RET_CHECK.*2 > 3 extra message"));
+}
+
+TEST(StatusMacros, RetCheckSucceeding) {
+  Status status = RetCheckSuccess();
+  EXPECT_IS_OK(status);
+}
+
+StatusOr<int> CreateIntSuccessfully() { return 42; }
+
+StatusOr<int> CreateIntUnsuccessfully() {
+  return tensorflow::errors::Internal("foobar");
+}
+
+TEST(StatusMacros, AssignOrAssertOnOK) {
+  TF_ASSIGN_OR_ASSERT_OK(int result, CreateIntSuccessfully());
+  EXPECT_EQ(42, result);
+}
+
+Status ReturnStatusOK() { return Status::OK(); }
+
+Status ReturnStatusError() { return (tensorflow::errors::Internal("foobar")); }
+
+using StatusReturningFunction = std::function<Status()>;
+
+StatusOr<int> CallStatusReturningFunction(StatusReturningFunction func) {
+  TF_RETURN_IF_ERROR(func());
+  return 42;
+}
+
+TEST(StatusMacros, ReturnIfErrorOnOK) {
+  StatusOr<int> rc = CallStatusReturningFunction(ReturnStatusOK);
+  EXPECT_IS_OK(rc);
+  EXPECT_EQ(42, rc.ConsumeValueOrDie());
+}
+
+TEST(StatusMacros, ReturnIfErrorOnError) {
+  StatusOr<int> rc = CallStatusReturningFunction(ReturnStatusError);
+  EXPECT_FALSE(rc.ok());
+  EXPECT_EQ(rc.status().code(), tensorflow::error::INTERNAL);
+}
+
+TEST(StatusMacros, AssignOrReturnSuccessfully) {
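+  // TF_ASSIGN_OR_RETURN unwraps the StatusOr<int> into `value` on success; on
+  // error it would return the failing status from the enclosing lambda.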
AssignOrReturnSuccessfully) {
+  Status status = []() {
+    TF_ASSIGN_OR_RETURN(int value, CreateIntSuccessfully());
+    EXPECT_EQ(value, 42);
+    return Status::OK();
+  }();
+  EXPECT_IS_OK(status);
+}
+
+TEST(StatusMacros, AssignOrReturnUnsuccessfully) {
+  Status status = []() {
+    TF_ASSIGN_OR_RETURN(int value, CreateIntUnsuccessfully());
+    (void)value;
+    return Status::OK();
+  }();
+  EXPECT_FALSE(status.ok());
+  EXPECT_EQ(status.code(), tensorflow::error::INTERNAL);
+}
+
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/statusor.cc b/tensorflow/compiler/xla/statusor.cc
new file mode 100644
index 0000000000..36f08fc99f
--- /dev/null
+++ b/tensorflow/compiler/xla/statusor.cc
@@ -0,0 +1,46 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/statusor.h"
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+namespace internal {
+
+Status StatusOrHelper::HandleInvalidStatusCtorArg() {
+  const char* kMessage =
+      "Status::OK is not a valid constructor argument to StatusOr<T>";
+  LOG(ERROR) << kMessage;
+  // In optimized builds, we will fall back to tensorflow::error::INTERNAL.
+  return Status(tensorflow::error::INTERNAL, kMessage);
+}
+
+Status StatusOrHelper::HandleNullObjectCtorArg() {
+  const char* kMessage =
+      "NULL is not a valid constructor argument to StatusOr<T*>";
+  LOG(ERROR) << kMessage;
+  // In optimized builds, we will fall back to tensorflow::error::INTERNAL.
+  return Status(tensorflow::error::INTERNAL, kMessage);
+}
+
+void StatusOrHelper::Crash(const Status& status) {
+  LOG(FATAL) << "Attempting to fetch value instead of handling error "
+             << status;
+}
+
+}  // namespace internal
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/statusor.h b/tensorflow/compiler/xla/statusor.h
new file mode 100644
index 0000000000..8046a2216f
--- /dev/null
+++ b/tensorflow/compiler/xla/statusor.h
@@ -0,0 +1,300 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// StatusOr<T> is the union of a Status object and a T
+// object. StatusOr models the concept of an object that is either a
+// usable value, or an error Status explaining why such a value is
+// not present. To this end, StatusOr<T> does not allow its Status
+// value to be Status::OK.
Furthermore, the value of a StatusOr +// must not be null. This is enforced by a debug check in most cases, +// but even when it is not, clients must not set the value to null. +// +// The primary use-case for StatusOr is as the return value of a +// function which may fail. +// +// Example client usage for a StatusOr, where T is not a pointer: +// +// StatusOr result = DoBigCalculationThatCouldFail(); +// if (result.ok()) { +// float answer = result.ValueOrDie(); +// printf("Big calculation yielded: %f", answer); +// } else { +// LOG(ERROR) << result.status(); +// } +// +// Example client usage for a StatusOr: +// +// StatusOr result = FooFactory::MakeNewFoo(arg); +// if (result.ok()) { +// std::unique_ptr foo(result.ValueOrDie()); +// foo->DoSomethingCool(); +// } else { +// LOG(ERROR) << result.status(); +// } +// +// Example client usage for a StatusOr>: +// +// StatusOr> result = FooFactory::MakeNewFoo(arg); +// if (result.ok()) { +// std::unique_ptr foo = std::move(result.ValueOrDie()); +// foo->DoSomethingCool(); +// } else { +// LOG(ERROR) << result.status(); +// } +// +// Example factory implementation returning StatusOr: +// +// StatusOr FooFactory::MakeNewFoo(int arg) { +// if (arg <= 0) { +// return tensorflow::InvalidArgument("Arg must be positive"); +// } else { +// return new Foo(arg); +// } +// } +// +// Note that the assignment operators require that destroying the currently +// stored value cannot invalidate the argument; in other words, the argument +// cannot be an alias for the current value, or anything owned by the current +// value. +#ifndef TENSORFLOW_COMPILER_XLA_STATUSOR_H_ +#define TENSORFLOW_COMPILER_XLA_STATUSOR_H_ + +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/core/platform/macros.h" + +namespace xla { + +#if defined(__clang__) +// Only clang supports warn_unused_result as a type annotation. +template +class TF_MUST_USE_RESULT StatusOr; +#endif + +template ::value> +class StatusOr { + template + friend class StatusOr; + + public: + typedef T element_type; + + // Construct a new StatusOr with Status::UNKNOWN status + StatusOr(); + + // Construct a new StatusOr with the given non-ok status. After calling + // this constructor, calls to ValueOrDie() will CHECK-fail. + // + // NOTE: Not explicit - we want to use StatusOr as a return + // value, so it is convenient and sensible to be able to do 'return + // Status()' when the return type is StatusOr. + // + // REQUIRES: status != Status::OK. This requirement is DCHECKed. + // In optimized builds, passing Status::OK here will have the effect + // of passing tensorflow::error::INTERNAL as a fallback. + StatusOr(Status status); // NOLINT + StatusOr(tensorflow::Status status); // NOLINT + + // Construct a new StatusOr with the given value. If T is a plain pointer, + // value must not be NULL. After calling this constructor, calls to + // ValueOrDie() will succeed, and calls to status() will return OK. + // + // NOTE: Not explicit - we want to use StatusOr as a return type + // so it is convenient and sensible to be able to do 'return T()' + // when the return type is StatusOr. + // + // REQUIRES: if T is a plain pointer, value != NULL. This requirement is + // DCHECKed. In optimized builds, passing a NULL pointer here will have + // the effect of passing tensorflow::error::INTERNAL as a fallback. + StatusOr(const T& value); // NOLINT + + // Copy constructor. 
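+  // (Meaningful only when T is copy constructible; the partial specialization
+  // for non-copy-constructible T below deletes the copy operations.)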
+ StatusOr(const StatusOr& other) = default; + + // Conversion copy constructor, T must be copy constructible from U + template + StatusOr(const StatusOr& other); + + // Assignment operator. + StatusOr& operator=(const StatusOr& other) = default; + + // Conversion assignment operator, T must be assignable from U + template + StatusOr& operator=(const StatusOr& other); + + // Move constructor and move-assignment operator. + StatusOr(StatusOr&& other) = default; + StatusOr& operator=(StatusOr&& other) = default; + + // Rvalue-reference overloads of the other constructors and assignment + // operators, to support move-only types and avoid unnecessary copying. + // + // Implementation note: we could avoid all these rvalue-reference overloads + // if the existing lvalue-reference overloads took their arguments by value + // instead. I think this would also let us omit the conversion assignment + // operator altogether, since we'd get the same functionality for free + // from the implicit conversion constructor and ordinary assignment. + // However, this could result in extra copy operations unless we use + // std::move to avoid them, and we can't use std::move because this code + // needs to be portable to C++03. + StatusOr(T&& value); // NOLINT + template + StatusOr(StatusOr&& other); + + // Returns a reference to our status. If this contains a T, then + // returns Status::OK. + const Status& status() const { return status_; } + + // Returns this->status().ok() + bool ok() const { return status_.ok(); } + + // Returns a reference to our current value, or CHECK-fails if !this->ok(). + const T& ValueOrDie() const; + T& ValueOrDie(); + + // Moves our current value out of this object and returns it, or CHECK-fails + // if !this->ok(). + // Use of this method is discouraged; prefer std::move(statusor.ValueOrDie()) + // instead. + T ConsumeValueOrDie() { return std::move(ValueOrDie()); } + + private: + Status status_; + T value_; +}; + +// Partial specialization for when T is not copy-constructible. This uses all +// methods from the core implementation, but removes copy assignment and copy +// construction. +template +class StatusOr : public StatusOr { + public: + // Remove copies. + StatusOr(const StatusOr& other) = delete; + StatusOr& operator=(const StatusOr& other) = delete; + template + StatusOr(const StatusOr& other) = delete; + StatusOr(const T& value) = delete; + + // Use the superclass version for other constructors and operators. + StatusOr() = default; + StatusOr(StatusOr&& other) = default; + StatusOr& operator=(StatusOr&& other) = default; + StatusOr(T&& value) // NOLINT + : StatusOr::StatusOr(std::move(value)) {} + StatusOr(Status status) // NOLINT + : StatusOr::StatusOr(std::move(status)) {} + StatusOr(tensorflow::Status status) // NOLINT + : StatusOr::StatusOr(std::move(status)) {} + template + StatusOr(StatusOr&& other) // NOLINT + : StatusOr::StatusOr(std::move(other)) {} +}; + +//////////////////////////////////////////////////////////////////////////////// +// Implementation details for StatusOr + +namespace internal { + +class StatusOrHelper { + public: + // Move type-agnostic error handling to the .cc. + static Status HandleInvalidStatusCtorArg(); + static Status HandleNullObjectCtorArg(); + static void Crash(const Status& status); + + // Customized behavior for StatusOr vs. StatusOr + template + struct Specialize; +}; + +template +struct StatusOrHelper::Specialize { + // For non-pointer T, a reference can never be NULL. 
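+  // The check below is therefore trivially false; only the T* specialization
+  // that follows compares against NULL.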
+ static inline bool IsValueNull(const T& t) { return false; } +}; + +template +struct StatusOrHelper::Specialize { + static inline bool IsValueNull(const T* t) { return t == NULL; } +}; + +} // namespace internal + +template +inline StatusOr::StatusOr() + : status_(tensorflow::error::UNKNOWN, "") {} + +template +inline StatusOr::StatusOr(Status status) + : status_(std::move(status)) { + if (status_.ok()) { + status_ = internal::StatusOrHelper::HandleInvalidStatusCtorArg(); + } +} + +template +inline StatusOr::StatusOr(tensorflow::Status status) + : status_(status) { + if (status_.ok()) { + status_ = internal::StatusOrHelper::HandleInvalidStatusCtorArg(); + } +} + +template +inline StatusOr::StatusOr(const T& value) + : value_(value) { + if (internal::StatusOrHelper::Specialize::IsValueNull(value)) { + status_ = internal::StatusOrHelper::HandleNullObjectCtorArg(); + } +} + +template +template +inline StatusOr::StatusOr(const StatusOr& other) + : status_(other.status_), value_(other.value_) {} + +template +inline StatusOr::StatusOr(T&& value) + : value_(std::move(value)) { + if (internal::StatusOrHelper::Specialize::IsValueNull(value_)) { + status_ = internal::StatusOrHelper::HandleNullObjectCtorArg(); + } +} + +template +template +inline StatusOr::StatusOr(StatusOr&& other) + : status_(std::move(other.status_)), value_(std::move(other.value_)) {} + +template +inline const T& StatusOr::ValueOrDie() const { + if (!ok()) { + internal::StatusOrHelper::Crash(status()); + } + return value_; +} + +template +inline T& StatusOr::ValueOrDie() { + if (!status_.ok()) { + internal::StatusOrHelper::Crash(status()); + } + return value_; +} + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_STATUSOR_H_ diff --git a/tensorflow/compiler/xla/statusor_test.cc b/tensorflow/compiler/xla/statusor_test.cc new file mode 100644 index 0000000000..d98eb27933 --- /dev/null +++ b/tensorflow/compiler/xla/statusor_test.cc @@ -0,0 +1,645 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// Unit tests for StatusOr + +#include "tensorflow/compiler/xla/statusor.h" + +#include +#include + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" + +namespace xla { +namespace { + +using tensorflow::Status; + +class Base1 { + public: + virtual ~Base1() {} + int pad; +}; + +class Base2 { + public: + virtual ~Base2() {} + int yetotherpad; +}; + +class Derived : public Base1, public Base2 { + public: + virtual ~Derived() {} + int evenmorepad; +}; + +class CopyNoAssign { + public: + explicit CopyNoAssign(int value) : foo(value) {} + CopyNoAssign(const CopyNoAssign& other) : foo(other.foo) {} + int foo; + + private: + const CopyNoAssign& operator=(const CopyNoAssign&); +}; + +StatusOr> ReturnUniquePtr() { + // Uses implicit constructor from T&& + return std::unique_ptr(new int(0)); +} + +TEST(StatusOr, ElementType) { + static_assert(std::is_same::element_type, int>(), ""); + static_assert(std::is_same::element_type, char>(), ""); +} + +TEST(StatusOr, TestMoveOnlyInitialization) { + StatusOr> thing(ReturnUniquePtr()); + ASSERT_TRUE(thing.ok()); + EXPECT_EQ(0, *thing.ValueOrDie()); + int* previous = thing.ValueOrDie().get(); + + thing = ReturnUniquePtr(); + EXPECT_TRUE(thing.ok()); + EXPECT_EQ(0, *thing.ValueOrDie()); + EXPECT_NE(previous, thing.ValueOrDie().get()); +} + +TEST(StatusOr, TestMoveOnlyStatusCtr) { + StatusOr> thing(tensorflow::errors::Cancelled("")); + ASSERT_FALSE(thing.ok()); +} + +TEST(StatusOr, TestMoveOnlyValueExtraction) { + StatusOr> thing(ReturnUniquePtr()); + ASSERT_TRUE(thing.ok()); + std::unique_ptr ptr = thing.ConsumeValueOrDie(); + EXPECT_EQ(0, *ptr); + + thing = std::move(ptr); + ptr = std::move(thing.ValueOrDie()); + EXPECT_EQ(0, *ptr); +} + +TEST(StatusOr, TestMoveOnlyConversion) { + StatusOr> const_thing(ReturnUniquePtr()); + EXPECT_TRUE(const_thing.ok()); + EXPECT_EQ(0, *const_thing.ValueOrDie()); + + // Test rvalue converting assignment + const int* const_previous = const_thing.ValueOrDie().get(); + const_thing = ReturnUniquePtr(); + EXPECT_TRUE(const_thing.ok()); + EXPECT_EQ(0, *const_thing.ValueOrDie()); + EXPECT_NE(const_previous, const_thing.ValueOrDie().get()); +} + +TEST(StatusOr, TestMoveOnlyVector) { + // Sanity check that StatusOr works in vector. + std::vector>> vec; + vec.push_back(ReturnUniquePtr()); + vec.resize(2); + auto another_vec = std::move(vec); + EXPECT_EQ(0, *another_vec[0].ValueOrDie()); + EXPECT_EQ(tensorflow::error::UNKNOWN, another_vec[1].status().code()); +} + +TEST(StatusOr, TestMoveWithValuesAndErrors) { + StatusOr status_or(string(1000, '0')); + StatusOr value1(string(1000, '1')); + StatusOr value2(string(1000, '2')); + StatusOr error1(Status(tensorflow::error::UNKNOWN, "error1")); + StatusOr error2(Status(tensorflow::error::UNKNOWN, "error2")); + + ASSERT_TRUE(status_or.ok()); + EXPECT_EQ(string(1000, '0'), status_or.ValueOrDie()); + + // Overwrite the value in status_or with another value. + status_or = std::move(value1); + ASSERT_TRUE(status_or.ok()); + EXPECT_EQ(string(1000, '1'), status_or.ValueOrDie()); + + // Overwrite the value in status_or with an error. + status_or = std::move(error1); + ASSERT_FALSE(status_or.ok()); + EXPECT_EQ("error1", status_or.status().error_message()); + + // Overwrite the error in status_or with another error. 
+ status_or = std::move(error2); + ASSERT_FALSE(status_or.ok()); + EXPECT_EQ("error2", status_or.status().error_message()); + + // Overwrite the error with a value. + status_or = std::move(value2); + ASSERT_TRUE(status_or.ok()); + EXPECT_EQ(string(1000, '2'), status_or.ValueOrDie()); +} + +TEST(StatusOr, TestCopyWithValuesAndErrors) { + StatusOr status_or(string(1000, '0')); + StatusOr value1(string(1000, '1')); + StatusOr value2(string(1000, '2')); + StatusOr error1(Status(tensorflow::error::UNKNOWN, "error1")); + StatusOr error2(Status(tensorflow::error::UNKNOWN, "error2")); + + ASSERT_TRUE(status_or.ok()); + EXPECT_EQ(string(1000, '0'), status_or.ValueOrDie()); + + // Overwrite the value in status_or with another value. + status_or = value1; + ASSERT_TRUE(status_or.ok()); + EXPECT_EQ(string(1000, '1'), status_or.ValueOrDie()); + + // Overwrite the value in status_or with an error. + status_or = error1; + ASSERT_FALSE(status_or.ok()); + EXPECT_EQ("error1", status_or.status().error_message()); + + // Overwrite the error in status_or with another error. + status_or = error2; + ASSERT_FALSE(status_or.ok()); + EXPECT_EQ("error2", status_or.status().error_message()); + + // Overwrite the error with a value. + status_or = value2; + ASSERT_TRUE(status_or.ok()); + EXPECT_EQ(string(1000, '2'), status_or.ValueOrDie()); + + // Verify original values unchanged. + EXPECT_EQ(string(1000, '1'), value1.ValueOrDie()); + EXPECT_EQ("error1", error1.status().error_message()); + EXPECT_EQ("error2", error2.status().error_message()); + EXPECT_EQ(string(1000, '2'), value2.ValueOrDie()); +} + +TEST(StatusOr, TestDefaultCtor) { + StatusOr thing; + EXPECT_FALSE(thing.ok()); + EXPECT_EQ(thing.status().code(), tensorflow::error::UNKNOWN); +} + +TEST(StatusOrDeathTest, TestDefaultCtorValue) { + StatusOr thing; + EXPECT_DEATH(thing.ValueOrDie(), ""); + + const StatusOr thing2; + EXPECT_DEATH(thing.ValueOrDie(), ""); +} + +TEST(StatusOr, TestStatusCtor) { + StatusOr thing(Status(tensorflow::error::CANCELLED, "")); + EXPECT_FALSE(thing.ok()); + EXPECT_EQ(thing.status().code(), tensorflow::error::CANCELLED); +} + +TEST(StatusOr, TestValueCtor) { + const int kI = 4; + const StatusOr thing(kI); + EXPECT_TRUE(thing.ok()); + EXPECT_EQ(kI, thing.ValueOrDie()); +} + +TEST(StatusOr, TestCopyCtorStatusOk) { + const int kI = 4; + const StatusOr original(kI); + const StatusOr copy(original); + EXPECT_EQ(copy.status(), original.status()); + EXPECT_EQ(original.ValueOrDie(), copy.ValueOrDie()); +} + +TEST(StatusOr, TestCopyCtorStatusNotOk) { + StatusOr original(Status(tensorflow::error::CANCELLED, "")); + StatusOr copy(original); + EXPECT_EQ(copy.status(), original.status()); +} + +TEST(StatusOr, TestCopyCtorNonAssignable) { + const int kI = 4; + CopyNoAssign value(kI); + StatusOr original(value); + StatusOr copy(original); + EXPECT_EQ(copy.status(), original.status()); + EXPECT_EQ(original.ValueOrDie().foo, copy.ValueOrDie().foo); +} + +TEST(StatusOr, TestCopyCtorStatusOKConverting) { + const int kI = 4; + StatusOr original(kI); + StatusOr copy(original); + EXPECT_EQ(copy.status(), original.status()); + EXPECT_DOUBLE_EQ(original.ValueOrDie(), copy.ValueOrDie()); +} + +TEST(StatusOr, TestCopyCtorStatusNotOkConverting) { + StatusOr original(Status(tensorflow::error::CANCELLED, "")); + StatusOr copy(original); + EXPECT_EQ(copy.status(), original.status()); +} + +TEST(StatusOr, TestAssignmentStatusOk) { + const int kI = 4; + StatusOr source(kI); + StatusOr target; + target = source; + EXPECT_EQ(target.status(), source.status()); + 
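+  // Unlike the move-based test above, copy assignment must leave the source
+  // readable as well.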
EXPECT_EQ(source.ValueOrDie(), target.ValueOrDie()); +} + +TEST(StatusOr, TestAssignmentStatusNotOk) { + StatusOr source(Status(tensorflow::error::CANCELLED, "")); + StatusOr target; + target = source; + EXPECT_EQ(target.status(), source.status()); +} + +TEST(StatusOr, TestStatus) { + StatusOr good(4); + EXPECT_TRUE(good.ok()); + StatusOr bad(Status(tensorflow::error::CANCELLED, "")); + EXPECT_FALSE(bad.ok()); + EXPECT_EQ(bad.status(), Status(tensorflow::error::CANCELLED, "")); +} + +TEST(StatusOr, TestValue) { + const int kI = 4; + StatusOr thing(kI); + EXPECT_EQ(kI, thing.ValueOrDie()); +} + +TEST(StatusOr, TestValueConst) { + const int kI = 4; + const StatusOr thing(kI); + EXPECT_EQ(kI, thing.ValueOrDie()); +} + +TEST(StatusOrDeathTest, TestValueNotOk) { + StatusOr thing(Status(tensorflow::error::CANCELLED, "cancelled")); + EXPECT_DEATH(thing.ValueOrDie(), "cancelled"); +} + +TEST(StatusOrDeathTest, TestValueNotOkConst) { + const StatusOr thing(Status(tensorflow::error::UNKNOWN, "")); + EXPECT_DEATH(thing.ValueOrDie(), ""); +} + +TEST(StatusOr, TestPointerDefaultCtor) { + StatusOr thing; + EXPECT_FALSE(thing.ok()); + EXPECT_EQ(thing.status().code(), tensorflow::error::UNKNOWN); +} + +TEST(StatusOrDeathTest, TestPointerDefaultCtorValue) { + StatusOr thing; + EXPECT_DEATH(thing.ValueOrDie(), ""); +} + +TEST(StatusOr, TestPointerStatusCtor) { + StatusOr thing(Status(tensorflow::error::CANCELLED, "")); + EXPECT_FALSE(thing.ok()); + EXPECT_EQ(thing.status(), Status(tensorflow::error::CANCELLED, "")); +} + +TEST(StatusOr, TestPointerValueCtor) { + const int kI = 4; + StatusOr thing(&kI); + EXPECT_TRUE(thing.ok()); + EXPECT_EQ(&kI, thing.ValueOrDie()); +} + +TEST(StatusOr, TestPointerCopyCtorStatusOk) { + const int kI = 0; + StatusOr original(&kI); + StatusOr copy(original); + EXPECT_EQ(copy.status(), original.status()); + EXPECT_EQ(original.ValueOrDie(), copy.ValueOrDie()); +} + +TEST(StatusOr, TestPointerCopyCtorStatusNotOk) { + StatusOr original(Status(tensorflow::error::CANCELLED, "")); + StatusOr copy(original); + EXPECT_EQ(copy.status(), original.status()); +} + +TEST(StatusOr, TestPointerCopyCtorStatusOKConverting) { + Derived derived; + StatusOr original(&derived); + StatusOr copy(original); + EXPECT_EQ(copy.status(), original.status()); + EXPECT_EQ(static_cast(original.ValueOrDie()), + copy.ValueOrDie()); +} + +TEST(StatusOr, TestPointerCopyCtorStatusNotOkConverting) { + StatusOr original(Status(tensorflow::error::CANCELLED, "")); + StatusOr copy(original); + EXPECT_EQ(copy.status(), original.status()); +} + +TEST(StatusOr, TestPointerAssignmentStatusOk) { + const int kI = 0; + StatusOr source(&kI); + StatusOr target; + target = source; + EXPECT_EQ(target.status(), source.status()); + EXPECT_EQ(source.ValueOrDie(), target.ValueOrDie()); +} + +TEST(StatusOr, TestPointerAssignmentStatusNotOk) { + StatusOr source(Status(tensorflow::error::CANCELLED, "")); + StatusOr target; + target = source; + EXPECT_EQ(target.status(), source.status()); +} + +TEST(StatusOr, TestPointerStatus) { + const int kI = 0; + StatusOr good(&kI); + EXPECT_TRUE(good.ok()); + StatusOr bad(Status(tensorflow::error::CANCELLED, "")); + EXPECT_EQ(bad.status(), Status(tensorflow::error::CANCELLED, "")); +} + +TEST(StatusOr, TestPointerValue) { + const int kI = 0; + StatusOr thing(&kI); + EXPECT_EQ(&kI, thing.ValueOrDie()); +} + +TEST(StatusOr, TestPointerValueConst) { + const int kI = 0; + const StatusOr thing(&kI); + EXPECT_EQ(&kI, thing.ValueOrDie()); +} + +// NOTE(tucker): tensorflow::StatusOr does not support this 
kind +// of resize op. +// TEST(StatusOr, StatusOrVectorOfUniquePointerCanResize) { +// using EvilType = std::vector>; +// static_assert(std::is_copy_constructible::value, ""); +// std::vector> v(5); +// v.reserve(v.capacity() + 10); +// } + +TEST(StatusOrDeathTest, TestPointerValueNotOk) { + StatusOr thing(Status(tensorflow::error::CANCELLED, "cancelled")); + EXPECT_DEATH(thing.ValueOrDie(), "cancelled"); +} + +TEST(StatusOrDeathTest, TestPointerValueNotOkConst) { + const StatusOr thing(Status(tensorflow::error::CANCELLED, "cancelled")); + EXPECT_DEATH(thing.ValueOrDie(), "cancelled"); +} + +static StatusOr MakeStatus() { return 100; } +// A factory to help us benchmark the various factory styles. All of +// the factory methods are marked as non-inlineable so as to more +// accurately simulate calling a factory for which you do not have +// visibility of implementation. Similarly, the value_ variable is +// marked volatile to prevent the compiler from getting too clever +// about detecting that the same value is used in all loop iterations. +template +class BenchmarkFactory { + public: + // Construct a new factory. Allocate an object which will always + // be the result of the factory methods. + BenchmarkFactory() : value_(new T) {} + + // Destroy this factory, including the result value. + ~BenchmarkFactory() { delete value_; } + + // A trivial factory that just returns the value. There is no status + // object that could be returned to encapsulate an error + T* TrivialFactory() TF_ATTRIBUTE_NOINLINE { return value_; } + + // A more sophisticated factory, which returns a status to indicate + // the result of the operation. The factory result is populated into + // the user provided pointer result. + Status ArgumentFactory(T** result) TF_ATTRIBUTE_NOINLINE { + *result = value_; + return Status::OK(); + } + + Status ArgumentFactoryFail(T** result) TF_ATTRIBUTE_NOINLINE { + *result = NULL; + return Status(tensorflow::error::CANCELLED, ""); + } + + Status ArgumentFactoryFailShortMsg(T** result) TF_ATTRIBUTE_NOINLINE { + *result = NULL; + return Status(::tensorflow::error::INTERNAL, ""); + } + + Status ArgumentFactoryFailLongMsg(T** result) TF_ATTRIBUTE_NOINLINE { + *result = NULL; + return Status(::tensorflow::error::INTERNAL, + "a big string of message junk that will never be read"); + } + + // A factory that returns a StatusOr. If the factory operation + // is OK, then the StatusOr will hold a T*. Otherwise, it will + // hold a status explaining the error. + StatusOr StatusOrFactory() TF_ATTRIBUTE_NOINLINE { + return static_cast(value_); + } + + StatusOr StatusOrFactoryFail() TF_ATTRIBUTE_NOINLINE { + return Status(tensorflow::error::CANCELLED, ""); + } + + StatusOr StatusOrFactoryFailShortMsg() TF_ATTRIBUTE_NOINLINE { + return Status(::tensorflow::error::INTERNAL, ""); + } + + StatusOr StatusOrFactoryFailLongMsg() TF_ATTRIBUTE_NOINLINE { + return Status(::tensorflow::error::INTERNAL, + "a big string of message junk that will never be read"); + } + + private: + T* volatile value_; + TF_DISALLOW_COPY_AND_ASSIGN(BenchmarkFactory); +}; + +// A simple type we use with the factory. +class BenchmarkType { + public: + BenchmarkType() {} + virtual ~BenchmarkType() {} + virtual void DoWork() TF_ATTRIBUTE_NOINLINE {} + + private: + TF_DISALLOW_COPY_AND_ASSIGN(BenchmarkType); +}; + +// Calibrate the amount of time spent just calling DoWork, since each of our +// tests will do this, we can subtract this out of benchmark results. 
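+// A rough per-call overhead for a factory style X can then be read as
+// time(BM_X) - time(BM_CalibrateWorkLoop).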
+static void BM_CalibrateWorkLoop(int iters) { + tensorflow::testing::StopTiming(); + BenchmarkFactory factory; + BenchmarkType* result = factory.TrivialFactory(); + tensorflow::testing::StartTiming(); + for (int i = 0; i != iters; ++i) { + if (result != NULL) result->DoWork(); + } +} +BENCHMARK(BM_CalibrateWorkLoop); + +// Measure the time taken to call into the factory, return the value, +// determine that it is OK, and invoke a trivial function. +static void BM_TrivialFactory(int iters) { + tensorflow::testing::StopTiming(); + BenchmarkFactory factory; + tensorflow::testing::StartTiming(); + for (int i = 0; i != iters; ++i) { + BenchmarkType* result = factory.TrivialFactory(); + if (result != NULL) result->DoWork(); + } +} +BENCHMARK(BM_TrivialFactory); + +// Measure the time taken to call into the factory, providing an +// out-param for the result, evaluating the status result and the +// result pointer, and invoking the trivial function. +static void BM_ArgumentFactory(int iters) { + tensorflow::testing::StopTiming(); + BenchmarkFactory factory; + tensorflow::testing::StartTiming(); + for (int i = 0; i != iters; ++i) { + BenchmarkType* result = NULL; + Status status = factory.ArgumentFactory(&result); + if (status.ok() && result != NULL) { + result->DoWork(); + } + } +} +BENCHMARK(BM_ArgumentFactory); + +// Measure the time to use the StatusOr factory, evaluate the result, +// and invoke the trivial function. +static void BM_StatusOrFactory(int iters) { + tensorflow::testing::StopTiming(); + BenchmarkFactory factory; + tensorflow::testing::StartTiming(); + for (int i = 0; i != iters; ++i) { + StatusOr result = factory.StatusOrFactory(); + if (result.ok()) { + result.ValueOrDie()->DoWork(); + } + } +} +BENCHMARK(BM_StatusOrFactory); + +// Measure the time taken to call into the factory, providing an +// out-param for the result, evaluating the status result and the +// result pointer, and invoking the trivial function. +static void BM_ArgumentFactoryFail(int iters) { + tensorflow::testing::StopTiming(); + BenchmarkFactory factory; + tensorflow::testing::StartTiming(); + for (int i = 0; i != iters; ++i) { + BenchmarkType* result = NULL; + Status status = factory.ArgumentFactoryFail(&result); + if (status.ok() && result != NULL) { + result->DoWork(); + } + } +} +BENCHMARK(BM_ArgumentFactoryFail); + +// Measure the time to use the StatusOr factory, evaluate the result, +// and invoke the trivial function. +static void BM_StatusOrFactoryFail(int iters) { + tensorflow::testing::StopTiming(); + BenchmarkFactory factory; + tensorflow::testing::StartTiming(); + for (int i = 0; i != iters; ++i) { + StatusOr result = factory.StatusOrFactoryFail(); + if (result.ok()) { + result.ValueOrDie()->DoWork(); + } + } +} +BENCHMARK(BM_StatusOrFactoryFail); + +// Measure the time taken to call into the factory, providing an +// out-param for the result, evaluating the status result and the +// result pointer, and invoking the trivial function. +static void BM_ArgumentFactoryFailShortMsg(int iters) { + tensorflow::testing::StopTiming(); + BenchmarkFactory factory; + tensorflow::testing::StartTiming(); + for (int i = 0; i != iters; ++i) { + BenchmarkType* result = NULL; + Status status = factory.ArgumentFactoryFailShortMsg(&result); + if (status.ok() && result != NULL) { + result->DoWork(); + } + } +} +BENCHMARK(BM_ArgumentFactoryFailShortMsg); + +// Measure the time to use the StatusOr factory, evaluate the result, +// and invoke the trivial function. 
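+// This variant fails with a short error message; the long-message variants
+// below show the additional cost of carrying a larger string.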
+static void BM_StatusOrFactoryFailShortMsg(int iters) { + tensorflow::testing::StopTiming(); + BenchmarkFactory factory; + tensorflow::testing::StartTiming(); + for (int i = 0; i != iters; ++i) { + StatusOr result = factory.StatusOrFactoryFailShortMsg(); + if (result.ok()) { + result.ValueOrDie()->DoWork(); + } + } +} +BENCHMARK(BM_StatusOrFactoryFailShortMsg); + +// Measure the time taken to call into the factory, providing an +// out-param for the result, evaluating the status result and the +// result pointer, and invoking the trivial function. +static void BM_ArgumentFactoryFailLongMsg(int iters) { + tensorflow::testing::StopTiming(); + BenchmarkFactory factory; + tensorflow::testing::StartTiming(); + for (int i = 0; i != iters; ++i) { + BenchmarkType* result = NULL; + Status status = factory.ArgumentFactoryFailLongMsg(&result); + if (status.ok() && result != NULL) { + result->DoWork(); + } + } +} +BENCHMARK(BM_ArgumentFactoryFailLongMsg); + +// Measure the time to use the StatusOr factory, evaluate the result, +// and invoke the trivial function. +static void BM_StatusOrFactoryFailLongMsg(int iters) { + tensorflow::testing::StopTiming(); + BenchmarkFactory factory; + tensorflow::testing::StartTiming(); + for (int i = 0; i != iters; ++i) { + StatusOr result = factory.StatusOrFactoryFailLongMsg(); + if (result.ok()) { + result.ValueOrDie()->DoWork(); + } + } +} +BENCHMARK(BM_StatusOrFactoryFailLongMsg); + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/test_helpers.cc b/tensorflow/compiler/xla/test_helpers.cc new file mode 100644 index 0000000000..02abfdeab8 --- /dev/null +++ b/tensorflow/compiler/xla/test_helpers.cc @@ -0,0 +1,69 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/regexp.h" + +namespace xla { +namespace testing { + +AssertionResult::AssertionResult(const AssertionResult& other) + : success_(other.success_), + message_(other.message_ != nullptr ? new std::string(*other.message_) + : static_cast(nullptr)) { +} + +// Returns the assertion's negation. Used with EXPECT/ASSERT_FALSE. +AssertionResult AssertionResult::operator!() const { + AssertionResult negation(!success_); + if (message_ != nullptr) negation << *message_; + return negation; +} + +AssertionResult& AssertionResult::operator=(const AssertionResult& ar) { + success_ = ar.success_; + message_.reset(ar.message_ != nullptr ? 
new std::string(*ar.message_) + : nullptr); + return *this; +} + +AssertionResult AssertionFailure() { return AssertionResult(false); } + +AssertionResult AssertionSuccess() { return AssertionResult(true); } + +std::function ContainsRegex( + const tensorflow::StringPiece regex) { + return [regex](const tensorflow::StringPiece to_test) { + if (RE2::PartialMatch( + tensorflow::RegexpStringPiece(to_test.data(), to_test.size()), + tensorflow::RegexpStringPiece(regex.data(), regex.size()))) { + return true; + } else { + LOG(ERROR) << "Expected to find " << regex << " in " << to_test; + return false; + } + }; +} + +std::function HasSubstr( + const tensorflow::StringPiece part) { + return [part](const tensorflow::StringPiece whole) { + return whole.contains(part); + }; +} + +} // namespace testing +} // namespace xla diff --git a/tensorflow/compiler/xla/test_helpers.h b/tensorflow/compiler/xla/test_helpers.h new file mode 100644 index 0000000000..f923d9f36c --- /dev/null +++ b/tensorflow/compiler/xla/test_helpers.h @@ -0,0 +1,355 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_TEST_HELPERS_H_ +#define TENSORFLOW_COMPILER_XLA_TEST_HELPERS_H_ + +#include +#include + +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/regexp.h" +#include "tensorflow/core/platform/test.h" + +// This module contains a minimal subset of gmock functionality just +// sufficient to execute the currently existing tests. +namespace util { +class Status; +} // namespace util + +namespace xla { +template +class Array2D; +class Literal; + +namespace testing { + +class AssertionResult { + public: + explicit AssertionResult(bool success) : success_(success) {} + + // Returns true iff the assertion succeeded. + operator bool() const { return success_; } // NOLINT + + // Returns the assertion's negation. Used with EXPECT/ASSERT_FALSE. + AssertionResult operator!() const; + + // Returns the text streamed into this AssertionResult. Test assertions + // use it when they fail (i.e., the predicate's outcome doesn't match the + // assertion's expectation). When nothing has been streamed into the + // object, returns an empty string. + const char* message() const { + return message_ != nullptr ? message_->c_str() : ""; + } + + // Streams a custom failure message into this object. + template + AssertionResult& operator<<(const T& value) { + AppendMessage(::testing::Message() << value); + return *this; + } + + // Allows streaming basic output manipulators such as endl or flush into + // this object. + AssertionResult& operator<<( + std::ostream& (*basic_manipulator)(std::ostream& stream)) { + AppendMessage(::testing::Message() << basic_manipulator); + return *this; + } + + // Copy operator. 
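+  // (Deep-copies the streamed message, if any; see the definition in
+  // test_helpers.cc.)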
+ AssertionResult(const AssertionResult& ar); + + // Assignment operator. + AssertionResult& operator=(const AssertionResult&); + + private: + // Appends the contents of message to message_. + void AppendMessage(const ::testing::Message& a_message) { + if (message_ == nullptr) message_.reset(new std::string); + message_->append(a_message.GetString().c_str()); + } + + bool success_ = false; + + // Stores the message describing the condition in case the + // expectation construct is not satisfied with the predicate's + // outcome. Referenced via a pointer to avoid taking too much stack + // frame space with test assertions. + std::unique_ptr message_; +}; + +AssertionResult AssertionFailure(); + +AssertionResult AssertionSuccess(); + +std::function ContainsRegex( + const tensorflow::StringPiece regex); + +std::function HasSubstr( + const tensorflow::StringPiece part); + +// Matcher for a vector of same-type values for which operator= is +// defined. +template +std::function& actual)> VectorMatcher( + const std::vector& expected) { + return [expected](const std::vector& actual) -> AssertionResult { + int len = expected.size(); + if (actual.size() != len) { + return AssertionFailure() << "Actual values len of " << actual.size() + << " != expected.size " << len; + } + for (int i = 0; i < len; ++i) { + if (actual[i] != expected[i]) { + return AssertionFailure() << "Element " << i << " actual " << actual[i] + << " != " << expected[i]; + } + } + return AssertionSuccess(); + }; +} + +// Approximate matcher for a vector of floats or similar. +template +std::function& actual)> +ApproxVectorMatcher(const std::vector& expected, float abs_diff, + float rel_diff) { + return [abs_diff, rel_diff, + expected](const std::vector& actual) -> AssertionResult { + int len = expected.size(); + if (actual.size() != len) { + AssertionResult ar = AssertionFailure() << "Actual values len of " + << actual.size() + << " != expected.size " << len; + LOG(ERROR) << ar.message(); + return ar; + } + for (int i = 0; i < len; ++i) { + T diff = actual[i] - expected[i]; + if (diff < 0) { + diff *= -1; + } + if (diff > abs_diff) { + T rdiff = (expected[i] != 0 ? diff / expected[i] : 0.0 * expected[i]); + if (rdiff > rel_diff) { + AssertionResult ar = AssertionFailure() + << "Element " << i << " actual " << actual[i] + << " != " << expected[i] + << "( abs_diff = " << diff + << ", rel_diff = " << rdiff << ")"; + LOG(ERROR) << ar.message(); + return ar; + } + } + } + return AssertionSuccess(); + }; +} + +// Matches a vector of same-type values against another, succeeding so +// long as they have the same length and every value in 'actual' +// matches one in 'expected.' Does not verify an exhaustive +// one-to-one mapping between the two. +template +std::function& actual)> +UnorderedElementsAre(const std::vector& expected) { + return [expected](const std::vector& actual) -> AssertionResult { + if (actual.size() != expected.size()) { + return AssertionFailure() << "sizes don't match"; + } + for (auto a : actual) { + bool found = false; + for (auto e : expected) { + if (a == e) { + found = true; + break; + } + } + if (!found) { + return AssertionFailure() << "actual element " << a + << " not in expected"; + } + } + return AssertionSuccess(); + }; +} + +// Overloaded cover functions for UnorderedElementsAre, for the numbers +// of values used in practice. 
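+// For example, given some hypothetical std::set<int> named visited:
+//
+//   EXPECT_MATCH(SetToVec(visited), testing::UnorderedMatcher<int>(1, 2, 3));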
+template +std::function& actual)> UnorderedMatcher( + T a) { + std::vector expected; + expected.push_back(a); + return testing::UnorderedElementsAre(expected); +} + +template +std::function& actual)> UnorderedMatcher( + T a, T b) { + std::vector expected; + expected.push_back(a); + expected.push_back(b); + return testing::UnorderedElementsAre(expected); +} + +template +std::function& actual)> UnorderedMatcher( + T a, T b, T c) { + std::vector expected; + expected.push_back(a); + expected.push_back(b); + expected.push_back(c); + return testing::UnorderedElementsAre(expected); +} + +template +std::function& actual)> UnorderedMatcher( + T a, T b, T c, T d) { + std::vector expected; + expected.push_back(a); + expected.push_back(b); + expected.push_back(c); + expected.push_back(d); + return testing::UnorderedElementsAre(expected); +} + +template +std::function& actual)> UnorderedMatcher( + T a, T b, T c, T d, T e) { + std::vector expected; + expected.push_back(a); + expected.push_back(b); + expected.push_back(c); + expected.push_back(d); + expected.push_back(e); + return testing::UnorderedElementsAre(expected); +} + +template +std::function& actual)> UnorderedMatcher( + T a, T b, T c, T d, T e, T f) { + std::vector expected; + expected.push_back(a); + expected.push_back(b); + expected.push_back(c); + expected.push_back(d); + expected.push_back(e); + expected.push_back(f); + return testing::UnorderedElementsAre(expected); +} + +// Overloaded cover functions for VectorMatcher for the numbers of +// elements used in practice. +template +std::function& actual)> OrderedMatcher( + T a) { + std::vector expected; + expected.push_back(a); + return testing::VectorMatcher(expected); +} + +template +std::function& actual)> OrderedMatcher( + T a, T b) { + std::vector expected; + expected.push_back(a); + expected.push_back(b); + return testing::VectorMatcher(expected); +} + +template +std::function& actual)> OrderedMatcher( + T a, T b, T c) { + std::vector expected; + expected.push_back(a); + expected.push_back(b); + expected.push_back(c); + return testing::VectorMatcher(expected); +} + +template +std::function& actual)> OrderedMatcher( + T a, T b, T c, T d) { + std::vector expected; + expected.push_back(a); + expected.push_back(b); + expected.push_back(c); + expected.push_back(d); + return testing::VectorMatcher(expected); +} + +// Convert a RepeatedField to a flat vector. +template +std::vector PBToVec(const tensorflow::protobuf::RepeatedField rf) { + return std::vector(rf.begin(), rf.end()); +} + +// Convert a List to a flat vector. +template +std::vector ListToVec(const std::list& l) { + return std::vector(l.begin(), l.end()); +} + +// Convert a Set to a flat vector. +template +std::vector SetToVec(const std::set& c) { + return std::vector(c.begin(), c.end()); +} + +// Convert an Array to a flat vector. +template +std::vector Array2DToVec(const Array2D& a) { + return std::vector(a.data(), a.data() + a.num_elements()); +} + +namespace internal_status { +inline const ::tensorflow::Status& GetStatus( + const ::tensorflow::Status& status) { + return status; +} + +template +inline const ::tensorflow::Status& GetStatus(const StatusOr& status) { + return status.status(); +} +} // namespace internal_status + +} // namespace testing +} // namespace xla + +// The following macros are similar to macros in gmock, but deliberately named +// differently in order to avoid conflicts in files which include both. 
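+//
+// For example, status_macros_test.cc (earlier in this change) checks a
+// failing TF_RET_CHECK like so:
+//
+//   Status status = RetCheckFail();
+//   EXPECT_EQ(status.code(), tensorflow::error::INTERNAL);
+//   EXPECT_MATCH(status.error_message(),
+//                xla::testing::ContainsRegex("RET_CHECK failure.*2 > 3"));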
+ +// Macros for testing the results of functions that return tensorflow::Status or +// StatusOr (for any type T). +#define EXPECT_IS_OK(expression) \ + EXPECT_EQ(tensorflow::Status::OK(), \ + xla::testing::internal_status::GetStatus(expression)) +#undef ASSERT_IS_OK +#define ASSERT_IS_OK(expression) \ + ASSERT_EQ(tensorflow::Status::OK(), \ + xla::testing::internal_status::GetStatus(expression)) + +// Macros that apply a Matcher to a Value, returning an +// AssertionResult which gets digested by a standard gunit macro. +#define EXPECT_MATCH(V, M) EXPECT_TRUE((M)((V))) +#define ASSERT_MATCH(V, M) ASSERT_TRUE(M(V)) + +#endif // TENSORFLOW_COMPILER_XLA_TEST_HELPERS_H_ diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD new file mode 100644 index 0000000000..93fe1fee4a --- /dev/null +++ b/tensorflow/compiler/xla/tests/BUILD @@ -0,0 +1,1436 @@ +# Description: +# Base testing infrastructure for XLA. + +licenses(["notice"]) # Apache 2.0 + +package( + default_visibility = [":friends"], + features = ["no_layering_check"], +) + +package_group( + name = "friends", + includes = [ + "//tensorflow/compiler/xla:friends", + ], +) + +# Filegroup used to collect source files for dependency checking. +filegroup( + name = "c_srcs", + data = glob([ + "**/*.cc", + "**/*.h", + ]), +) + +load("//tensorflow/compiler/xla:xla.bzl", "export_dynamic_linkopts") +load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") +load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites") +load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_test_macros") + +# Generate test_suites for all backends, named "${backend}_tests". +generate_backend_suites() + +cc_library( + name = "test_macros_header", + testonly = True, + hdrs = ["test_macros.h"], + deps = [ + "//tensorflow/compiler/xla:types", + "//tensorflow/core:test", + ], +) + +# Generate a test_macros_${BACKEND} library per backend with the proper copts. 
+generate_backend_test_macros() + +cc_library( + name = "test_utils", + testonly = True, + hdrs = ["test_utils.h"], + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "literal_test_util", + testonly = True, + srcs = ["literal_test_util.cc"], + hdrs = ["literal_test_util.h"], + deps = [ + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:array3d", + "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +cc_library( + name = "hlo_test_base", + testonly = True, + srcs = ["hlo_test_base.cc"], + hdrs = ["hlo_test_base.h"], + deps = [ + ":literal_test_util", + "//tensorflow/compiler/xla:shape_layout", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/legacy_flags:hlo_test_base_flags", + "//tensorflow/compiler/xla/service", + "//tensorflow/compiler/xla/service:backend", + "//tensorflow/compiler/xla/service:compiler", + "//tensorflow/compiler/xla/service:computation_layout", + "//tensorflow/compiler/xla/service:executable", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_execution_profile", + "//tensorflow/compiler/xla/service:hlo_graph_dumper", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/service:transfer_manager", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core:test", + "//third_party/eigen3", + ], +) + +cc_binary( + name = "local_client_aot_test_helper", + srcs = ["local_client_aot_test_helper.cc"], + deps = [ + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/core:lib", + ], +) + +genrule( + name = "local_client_aot_test_computation", + outs = ["local_client_aot_test_computation.o"], + cmd = "$(location :local_client_aot_test_helper) $(TARGET_CPU) > $(OUTS)", + local = 1, + tools = [":local_client_aot_test_helper"], +) + +cc_library( + name = "client_library_test_base", + testonly = True, + srcs = ["client_library_test_base.cc"], + hdrs = ["client_library_test_base.h"], + deps = [ + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:array3d", + "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:hlo_pass_pipeline_flags", + "//tensorflow/compiler/xla/tests:literal_test_util", + 
"//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core:test", + ], +) + +cc_library( + name = "codegen_test_base", + testonly = True, + srcs = ["codegen_test_base.cc"], + hdrs = ["codegen_test_base.h"], + data = [ + "@llvm//:FileCheck", + ], + deps = [ + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:backend", + "//tensorflow/compiler/xla/service:compiler", + "//tensorflow/compiler/xla/service:executable", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +cc_library( + name = "local_client_test_base", + testonly = True, + srcs = ["local_client_test_base.cc"], + hdrs = ["local_client_test_base.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/service:device_memory_allocator", + "//tensorflow/compiler/xla/service:local_service", + "//tensorflow/compiler/xla/service:platform_util", + "//tensorflow/compiler/xla/service:shaped_buffer", + "//tensorflow/compiler/xla/service:transfer_manager", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + ], +) + +xla_test( + name = "bad_rng_shape_validation_test", + srcs = ["bad_rng_shape_validation_test.cc"], + deps = [ + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "check_execution_arity_test", + srcs = ["check_execution_arity_test.cc"], + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "query_inferred_shape_test", + srcs = ["query_inferred_shape_test.cc"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/core:test", + ], +) + +xla_test( 
+ name = "while_test", + srcs = ["while_test.cc"], + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "axpy_simple_test", + srcs = ["axpy_simple_test.cc"], + deps = [ + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "map_test", + srcs = ["map_test.cc"], + deps = [ + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla:xla_proto", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "params_test", + srcs = ["params_test.cc"], + deps = [ + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "pred_test", + srcs = ["pred_test.cc"], + deps = [ + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "select_test", + srcs = ["select_test.cc"], + deps = [ + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + 
"//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "unary_op_test", + srcs = ["unary_op_test.cc"], + deps = [ + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "scalar_computations_test", + srcs = ["scalar_computations_test.cc"], + shard_count = 16, + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:llvm_backend_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "deallocation_test", + srcs = ["deallocation_test.cc"], + deps = [ + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "deconstruct_tuple_test", + srcs = ["deconstruct_tuple_test.cc"], + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "array_elementwise_ops_test", + srcs = ["array_elementwise_ops_test.cc"], + # This test includes comparisons to NAN, so disable fast-math. 
+    backend_args = {
+        "cpu": ["--xla_fast_math=false"],
+        "cpu_parallel": ["--xla_fast_math=false"],
+        "gpu": ["--xla_fast_math=false"],
+    },
+    shard_count = 25,
+    deps = [
+        "//tensorflow/compiler/xla:array2d",
+        "//tensorflow/compiler/xla:array3d",
+        "//tensorflow/compiler/xla:array4d",
+        "//tensorflow/compiler/xla:literal_util",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/compiler/xla/client:computation_builder",
+        "//tensorflow/compiler/xla/client:global_data",
+        "//tensorflow/compiler/xla/client:local_client",
+        "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+        "//tensorflow/compiler/xla/legacy_flags:llvm_backend_flags",
+        "//tensorflow/compiler/xla/tests:client_library_test_base",
+        "//tensorflow/compiler/xla/tests:literal_test_util",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:test",
+    ],
+)
+
+xla_test(
+    name = "dot_operation_test",
+    srcs = ["dot_operation_test.cc"],
+    deps = [
+        "//tensorflow/compiler/xla:array2d",
+        "//tensorflow/compiler/xla:array3d",
+        "//tensorflow/compiler/xla:reference_util",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla/client:computation_builder",
+        "//tensorflow/compiler/xla/client:local_client",
+        "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+        "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags",
+        "//tensorflow/compiler/xla/legacy_flags:layout_util_flags",
+        "//tensorflow/compiler/xla/tests:client_library_test_base",
+        "//tensorflow/compiler/xla/tests:literal_test_util",
+        "//tensorflow/compiler/xla/tests:test_utils",
+        "//tensorflow/core:framework_internal",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:test",
+    ],
+)
+
+# Tests the dot operation in some cases that can be performed via a
+# runtime call on some backends - e.g. a runtime call to Eigen.
+xla_test(
+    name = "dot_operation_runtime_test",
+    srcs = ["dot_operation_test.cc"],
+    backend_args = {
+        "cpu": ["--xla_cpu_use_eigen"],
+        "cpu_parallel": ["--xla_cpu_use_eigen"],
+    },
+    deps = [
+        "//tensorflow/compiler/xla:array2d",
+        "//tensorflow/compiler/xla:array3d",
+        "//tensorflow/compiler/xla:reference_util",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla/client:computation_builder",
+        "//tensorflow/compiler/xla/client:local_client",
+        "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+        "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags",
+        "//tensorflow/compiler/xla/legacy_flags:layout_util_flags",
+        "//tensorflow/compiler/xla/tests:client_library_test_base",
+        "//tensorflow/compiler/xla/tests:literal_test_util",
+        "//tensorflow/compiler/xla/tests:test_utils",
+        "//tensorflow/core:framework_internal",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:test",
+    ],
+)
+
+# Repeat dot_operation_runtime_test with single-threaded Eigen.
+xla_test( + name = "dot_operation_single_threaded_runtime_test", + srcs = ["dot_operation_test.cc"], + backend_args = { + "cpu": [ + "--xla_cpu_use_eigen", + "--xla_cpu_multi_thread_eigen=false", + ], + "cpu_parallel": [ + "--xla_cpu_use_eigen", + "--xla_cpu_multi_thread_eigen=false", + ], + }, + deps = [ + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:array3d", + "//tensorflow/compiler/xla:reference_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags", + "//tensorflow/compiler/xla/legacy_flags:layout_util_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "dot_operation_rowmajor_runtime_test", + srcs = ["dot_operation_test.cc"], + backend_args = { + "cpu": [ + "--xla_cpu_use_eigen", + "--xla_default_layout=major2minor", + ], + "cpu_parallel": [ + "--xla_cpu_use_eigen", + "--xla_default_layout=major2minor", + ], + }, + deps = [ + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:array3d", + "//tensorflow/compiler/xla:reference_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags", + "//tensorflow/compiler/xla/legacy_flags:layout_util_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "transpose_test", + srcs = ["transpose_test.cc"], + deps = [ + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:reference_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "constants_test", + srcs = ["constants_test.cc"], + deps = [ + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:array3d", + "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "convolution_test", + timeout = "long", + srcs = ["convolution_test.cc"], + shard_count = 25, + deps = [ + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:array4d", + 
"//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:reference_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:padding", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "convolution_variants_test", + timeout = "long", + srcs = ["convolution_variants_test.cc"], + backend_tags = { + # TODO(b/31436974): Fix msan failure. Failed on 2016-09-12. + "cpu": ["nomsan"], + "cpu_parallel": ["nomsan"], + }, + shard_count = 30, + deps = [ + "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:reference_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:padding", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "convolution_dimension_numbers_test", + timeout = "long", + srcs = ["convolution_dimension_numbers_test.cc"], + shard_count = 20, + deps = [ + "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:reference_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:padding", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "batch_normalization_test", + srcs = ["batch_normalization_test.cc"], + deps = [ + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "slice_test", + srcs = ["slice_test.cc"], + shard_count = 40, + deps = [ + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:reference_util", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", 
+ "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "multidimensional_slice_test", + srcs = ["multidimensional_slice_test.cc"], + deps = [ + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:array3d", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "dynamic_ops_test", + srcs = ["dynamic_ops_test.cc"], + deps = [ + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:reference_util", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/service:device_memory_allocator", + "//tensorflow/compiler/xla/service:local_service", + "//tensorflow/compiler/xla/service:platform_util", + "//tensorflow/compiler/xla/service:shaped_buffer", + "//tensorflow/compiler/xla/service:transfer_manager", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "tuple_test", + srcs = ["tuple_test.cc"], + deps = [ + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "vector_ops_reduce_test", + srcs = ["vector_ops_reduce_test.cc"], + deps = [ + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:array3d", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "reduce_test", + srcs = ["reduce_test.cc"], + shard_count = 40, + deps = [ + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:reference_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + 
"//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "reduce_window_test", + timeout = "long", + srcs = ["reduce_window_test.cc"], + deps = [ + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:array3d", + "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:reference_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:padding", + "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "select_and_scatter_test", + timeout = "long", + srcs = ["select_and_scatter_test.cc"], + deps = [ + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:reference_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:padding", + "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "copy_test", + srcs = ["copy_test.cc"], + deps = [ + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "call_test", + srcs = ["call_test.cc"], + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "custom_call_test", + srcs = ["custom_call_test.cc"], + linkopts = export_dynamic_linkopts, + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/tests:hlo_test_base", + 
"//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "binop_scaling_test", + srcs = ["binop_scaling_test.cc"], + deps = [ + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:reference_util", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "broadcast_simple_test", + srcs = ["broadcast_simple_test.cc"], + deps = [ + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "pad_test", + srcs = ["pad_test.cc"], + deps = [ + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:reference_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "fmax_test", + srcs = ["fmax_test.cc"], + deps = [ + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "log_test", + srcs = ["log_test.cc"], + deps = [ + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "matrix_ops_simple_test", + srcs = ["matrix_ops_simple_test.cc"], + deps = [ + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:reference_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + 
"//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "prng_test", + srcs = ["prng_test.cc"], + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "reshape_test", + srcs = ["reshape_test.cc"], + shard_count = 30, + deps = [ + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:reference_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "reverse_test", + srcs = ["reverse_test.cc"], + deps = [ + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "vector_ops_simple_test", + srcs = ["vector_ops_simple_test.cc"], + deps = [ + "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "concat_test", + srcs = ["concat_test.cc"], + deps = [ + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:array3d", + "//tensorflow/compiler/xla:reference_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + 
"//tensorflow/core:test", + ], +) + +xla_test( + name = "convert_test", + srcs = ["convert_test.cc"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "compilation_cache_test", + srcs = ["compilation_cache_test.cc"], + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla:xla_proto", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "floor_ceil_test", + srcs = ["floor_ceil_test.cc"], + deps = [ + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "compute_constant_test", + srcs = ["compute_constant_test.cc"], + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "client_test", + srcs = ["client_test.cc"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "inprocess_service_test", + srcs = ["inprocess_service_test.cc"], + deps = [ + "//tensorflow/compiler/xla:array2d", + 
"//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "replay_test", + srcs = ["replay_test.cc"], + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:protobuf_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/service:session_proto", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "broadcast_test", + srcs = ["broadcast_test.cc"], + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "round_trip_packed_literal_test", + srcs = ["round_trip_packed_literal_test.cc"], + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:packed_literal_reader", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "fusion_test", + srcs = ["fusion_test.cc"], + deps = [ + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +cc_test( + name = "local_client_aot_test", + srcs = [ + "local_client_aot_test.cc", + ":local_client_aot_test_computation.o", + ], + deps = [ + "//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +xla_test( + name = "round_trip_transfer_test", + srcs = ["round_trip_transfer_test.cc"], + deps = [ + 
"//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "set_return_value_test", + srcs = ["set_return_value_test.cc"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "reshape_motion_test", + srcs = ["reshape_motion_test.cc"], + deps = [ + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:reference_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +cc_test( + name = "literal_test_util_test", + srcs = ["literal_test_util_test.cc"], + deps = [ + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +# ----------------------------------------------------------------------------- + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc new file mode 100644 index 0000000000..cf6f9a825c --- /dev/null +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -0,0 +1,1662 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include <cmath>
+#include <limits>
+#include <memory>
+#include <numeric>
+#include <vector>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/array3d.h"
+#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/llvm_backend_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class ArrayElementwiseOpTest : public ClientLibraryTestBase {
+ public:
+  ErrorSpec error_spec_{0.0001};
+};
+
+class ArrayElementwiseOpTestParamCount
+    : public ArrayElementwiseOpTest,
+      public ::testing::WithParamInterface<int> {};
+
+XLA_TEST_F(ArrayElementwiseOpTest, NegConstantZeroElementF32) {
+  ComputationBuilder builder(client_, TestName());
+  auto a = builder.ConstantR1<float>({});
+  auto result = builder.Neg(a);
+
+  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
+}
+
+TEST_F(ArrayElementwiseOpTest, NegConstantF32) {
+  ComputationBuilder builder(client_, TestName());
+  auto a = builder.ConstantR1<float>({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f});
+  auto result = builder.Neg(a);
+
+  ComputeAndCompareR1<float>(&builder, {2.5f, -3.14f, -2.25f, 10.0f, -6.0f}, {},
+                             error_spec_);
+}
+
+TEST_F(ArrayElementwiseOpTest, NegConstantS32) {
+  ComputationBuilder builder(client_, TestName());
+  auto a = builder.ConstantR1<int32>({-1, 0, 1, 324,
+                                      std::numeric_limits<int32>::min(),
+                                      std::numeric_limits<int32>::max()});
+  auto result = builder.Neg(a);
+
+  // -min == min for int32 due to an overflow. In C++ it is undefined behavior
+  // to do this calculation. For XLA we have not specified that, so it
+  // ought to work.
+  ComputeAndCompareR1<int32>(&builder,
+                             {1, 0, -1, -324, std::numeric_limits<int32>::min(),
+                              -std::numeric_limits<int32>::max()},
+                             {});
+}
+
+TEST_F(ArrayElementwiseOpTest, AddTwoConstantF32s) {
+  ComputationBuilder builder(client_, TestName());
+  auto a = builder.ConstantR1<float>({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f});
+  auto b = builder.ConstantR1<float>({100.0f, 3.13f, 2.75f, 10.5f, -999.0f});
+  auto add = builder.Add(a, b);
+
+  ComputeAndCompareR1<float>(&builder, {97.5f, 6.27f, 5.0f, 0.5f, -993.0f}, {},
+                             error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementF32s) {
+  ComputationBuilder builder(client_, TestName());
+  auto a = builder.ConstantR1<float>({});
+  auto b = builder.ConstantR1<float>({});
+  auto add = builder.Add(a, b);
+
+  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
+}
+
+TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) {
+  const int count = GetParam();
+  ComputationBuilder builder(client_, TestName());
+  std::vector<float> a_values;
+  std::vector<float> b_values;
+  for (int i = 0; i < count; ++i) {
+    a_values.push_back(i / static_cast<float>(count));
+    b_values.push_back(2 * i / static_cast<float>(count + 2));
+  }
+
+  std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR1<float>({a_values});
+  std::unique_ptr<GlobalData> a_data =
+      client_->TransferToServer(*a_literal).ConsumeValueOrDie();
+  auto a_constant = builder.ConstantR1<float>(a_values);
+  auto a_param = builder.Parameter(0, a_literal->shape(), "a_param");
+
+  std::unique_ptr<Literal> b_literal = LiteralUtil::CreateR1<float>({b_values});
+  std::unique_ptr<GlobalData> b_data =
+      client_->TransferToServer(*b_literal).ConsumeValueOrDie();
+  auto b_constant = builder.Parameter(1, a_literal->shape(), "b_param");
+  auto b_param = builder.ConstantR1<float>(b_values);
+
+  auto sum1 = builder.Add(a_constant, b_constant);
+  auto sum2 = builder.Add(a_constant, b_param);
+  auto sum3 = builder.Add(a_param, b_constant);
+  auto sum4 = builder.Add(a_param, b_param);
+
+  auto sum = builder.Add(sum1, sum2);
+  sum = builder.Add(sum, sum3);
+  sum = builder.Add(sum, sum4);
+
+  std::vector<float> expected;
+  for (int64 i = 0; i < count; ++i) {
+    expected.push_back(4 * (a_values[i] + b_values[i]));
+  }
+
+  ComputeAndCompareR1<float>(&builder, expected, {a_data.get(), b_data.get()},
+                             error_spec_);
+}
+
+TEST_F(ArrayElementwiseOpTest, SubTwoConstantF32s) {
+  ComputationBuilder builder(client_, TestName());
+  auto a = builder.ConstantR1<float>({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f});
+  auto b = builder.ConstantR1<float>({100.0f, 3.13f, 2.75f, 10.5f, -999.0f});
+  auto add = builder.Sub(a, b);
+
+  ComputeAndCompareR1<float>(&builder, {-102.5f, 0.01f, -0.5f, -20.5f, 1005.0f},
+                             {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementF32s) {
+  ComputationBuilder builder(client_, TestName());
+  auto a = builder.ConstantR1<float>({});
+  auto b = builder.ConstantR1<float>({});
+  auto add = builder.Sub(a, b);
+
+  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
+}
+
+TEST_F(ArrayElementwiseOpTest, SubTwoConstantS32s) {
+  ComputationBuilder builder(client_, TestName());
+  auto a = builder.ConstantR1<int32>({-1, 0, 2, 1000000000});
+  auto b = builder.ConstantR1<int32>({-1, 2, 1, -1});
+  auto add = builder.Sub(a, b);
+
+  ComputeAndCompareR1<int32>(&builder, {0, -2, 1, 1000000001}, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementS32s) {
+  ComputationBuilder builder(client_, TestName());
+  auto a = builder.ConstantR1<int32>({});
+  auto b = builder.ConstantR1<int32>({});
+  auto add = builder.Sub(a, b);
+
+  ComputeAndCompareR1<int32>(&builder, {}, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, DivTwoConstantF32s) {
+  ComputationBuilder builder(client_, TestName());
+  auto a = builder.ConstantR1<float>({-2.5f,
25.5f, 2.25f, -10.0f, 6.0f}); + auto b = builder.ConstantR1({10.0f, 5.1f, 1.0f, 10.0f, -6.0f}); + auto add = builder.Div(a, b); + + ComputeAndCompareR1(&builder, {-0.25f, 5.0f, 2.25f, -1.0f, -1.0f}, {}, + error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementF32s) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + auto b = builder.ConstantR1({}); + auto add = builder.Div(a, b); + + ComputeAndCompareR1(&builder, {}, {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, RemF32s) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1( + {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f, 3.0f, 3.0f, -1.0f, -8.0f}); + auto b = builder.ConstantR1( + {10.0f, 5.1f, 1.0f, 10.0f, -6.0f, 2.0f, -2.0f, 7.0f, -4.0f}); + auto add = builder.Rem(a, b); + + ComputeAndCompareR1( + &builder, {-2.5f, 0.0f, 0.25f, 0.0f, -0.0f, 1.0f, 1.0f, -1.0f, -0.0f}, {}, + error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, RemZeroElementF32s) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + auto b = builder.ConstantR1({}); + auto add = builder.Rem(a, b); + + ComputeAndCompareR1(&builder, {}, {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, RemF64s) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1( + {-2.5, 25.5, 2.25, -10.0, 6.0, 3.0, 3.0, -1.0, -8.0}); + auto b = builder.ConstantR1( + {10.0, 5.1, 1.0, 10.0, -6.0, 2.0, -2.0, 7.0, -4.0}); + auto add = builder.Rem(a, b); + + ComputeAndCompareR1( + &builder, {-2.5, 0.0, 0.25, 0.0, -0.0, 1.0, 1.0, -1.0, -0.0}, {}, + error_spec_); +} + +TEST_F(ArrayElementwiseOpTest, MulTwoConstantF32s) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); + auto b = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, -6.0f}); + auto add = builder.Mul(a, b); + + ComputeAndCompareR1(&builder, {-25.0f, 127.5f, 2.25f, -100.0f, -36.0f}, + {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementF32s) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + auto b = builder.ConstantR1({}); + auto add = builder.Mul(a, b); + + ComputeAndCompareR1(&builder, {}, {}, error_spec_); +} + +TEST_F(ArrayElementwiseOpTest, MulTwoConstantS32s) { + std::vector data = {0, + 1, + -1, + 1234, + 0x1a243514, + std::numeric_limits::max(), + std::numeric_limits::min()}; + // Form the test data set using all products of 'data' with itself. + std::vector a_data, b_data, expected; + for (int32 a : data) { + for (int32 b : data) { + a_data.push_back(a); + b_data.push_back(b); + expected.push_back(static_cast(a) * static_cast(b)); + } + } + + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1(a_data); + auto b = builder.ConstantR1(b_data); + auto add = builder.Mul(a, b); + + ComputeAndCompareR1(&builder, expected, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementS32s) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + auto b = builder.ConstantR1({}); + auto add = builder.Mul(a, b); + + ComputeAndCompareR1(&builder, {}, {}); +} + +TEST_F(ArrayElementwiseOpTest, MulTwoConstantU32s) { + std::vector data = {0, 1, 0xDEADBEEF, 1234, + 0x1a243514, 0xFFFFFFFF, 0x80808080}; + + // Form the test data set using all products of 'data' with itself. 
+ std::vector a_data, b_data, expected; + for (uint32 a : data) { + for (uint32 b : data) { + a_data.push_back(a); + b_data.push_back(b); + expected.push_back(a * b); + } + } + + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1(a_data); + auto b = builder.ConstantR1(b_data); + auto add = builder.Mul(a, b); + + ComputeAndCompareR1(&builder, expected, {}); +} + +TEST_F(ArrayElementwiseOpTest, LogicalAnd) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({false, false, true, true}); + auto b = builder.ConstantR1({false, true, false, true}); + auto out = builder.LogicalAnd(a, b); + + ComputeAndCompareR1(&builder, {false, false, false, true}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, LogicalAndZeroElement) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + auto b = builder.ConstantR1({}); + auto out = builder.LogicalAnd(a, b); + + ComputeAndCompareR1(&builder, {}, {}); +} + +TEST_F(ArrayElementwiseOpTest, LogicalOr) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({false, false, true, true}); + auto b = builder.ConstantR1({false, true, false, true}); + auto out = builder.LogicalOr(a, b); + + ComputeAndCompareR1(&builder, {false, true, true, true}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, LogicalOrZeroElement) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + auto b = builder.ConstantR1({}); + auto out = builder.LogicalOr(a, b); + + ComputeAndCompareR1(&builder, {}, {}); +} + +TEST_F(ArrayElementwiseOpTest, LogicalNot) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({false, true, true, false}); + auto out = builder.LogicalNot(a); + + ComputeAndCompareR1(&builder, {true, false, false, true}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, LogicalNotZeroElement) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + auto out = builder.LogicalNot(a); + + ComputeAndCompareR1(&builder, {}, {}); +} + +TEST_F(ArrayElementwiseOpTest, CompareEqF32s) { + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR1({-2.5f, 25.5f, 2.25f, NAN, 6.0f}); + auto rhs = builder.ConstantR1({10.0f, 5.0f, 2.25f, 10.0f, NAN}); + auto compare = builder.Eq(lhs, rhs); + + ComputeAndCompareR1(&builder, {false, false, true, false, false}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementF32s) { + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR1({}); + auto rhs = builder.ConstantR1({}); + auto compare = builder.Eq(lhs, rhs); + + ComputeAndCompareR1(&builder, {}, {}); +} + +TEST_F(ArrayElementwiseOpTest, CompareGeF32s) { + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR1({-2.5f, 25.5f, 2.25f, NAN, 6.0f}); + auto rhs = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, NAN}); + auto compare = builder.Ge(lhs, rhs); + + ComputeAndCompareR1(&builder, {false, true, true, false, false}, {}); +} + +TEST_F(ArrayElementwiseOpTest, CompareGtF32s) { + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR1({-2.5f, 25.5f, 2.25f, NAN, 6.0f}); + auto rhs = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, NAN}); + auto compare = builder.Gt(lhs, rhs); + + ComputeAndCompareR1(&builder, {false, true, true, false, false}, {}); +} + +TEST_F(ArrayElementwiseOpTest, CompareLeF32s) { + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR1({-2.5f, 5.0f, 
2.25f, NAN, 6.0f}); + auto rhs = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, NAN}); + auto compare = builder.Le(lhs, rhs); + + ComputeAndCompareR1(&builder, {true, true, false, false, false}, {}); +} + +TEST_F(ArrayElementwiseOpTest, CompareLtF32s) { + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR1({-2.5f, 25.5f, 2.25f, NAN, 6.0f}); + auto rhs = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, NAN}); + auto compare = builder.Lt(lhs, rhs); + + ComputeAndCompareR1(&builder, {true, false, false, false, false}, {}); +} + +TEST_F(ArrayElementwiseOpTest, CompareEqS32s) { + const int32 min = std::numeric_limits::min(); + const int32 max = std::numeric_limits::max(); + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR1({min, min, min, 0, 0, 0, max, max, max}); + auto rhs = builder.ConstantR1({min, 0, max, -1, 0, 1, min, 0, max}); + auto compare = builder.Eq(lhs, rhs); + + ComputeAndCompareR1( + &builder, {true, false, false, false, true, false, false, false, true}, + {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementS32s) { + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR1({}); + auto rhs = builder.ConstantR1({}); + auto compare = builder.Eq(lhs, rhs); + + ComputeAndCompareR1(&builder, {}, {}); +} + +TEST_F(ArrayElementwiseOpTest, CompareNeS32s) { + const int32 min = std::numeric_limits::min(); + const int32 max = std::numeric_limits::max(); + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR1({min, min, min, 0, 0, 0, max, max, max}); + auto rhs = builder.ConstantR1({min, 0, max, -1, 0, 1, min, 0, max}); + auto compare = builder.Ne(lhs, rhs); + + ComputeAndCompareR1( + &builder, {false, true, true, true, false, true, true, true, false}, {}); +} + +TEST_F(ArrayElementwiseOpTest, CompareGeS32s) { + const int32 min = std::numeric_limits::min(); + const int32 max = std::numeric_limits::max(); + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR1({min, min, min, 0, 0, 0, max, max, max}); + auto rhs = builder.ConstantR1({min, 0, max, -1, 0, 1, min, 0, max}); + auto compare = builder.Ge(lhs, rhs); + + ComputeAndCompareR1( + &builder, {true, false, false, true, true, false, true, true, true}, {}); +} + +TEST_F(ArrayElementwiseOpTest, CompareGtS32s) { + const int32 min = std::numeric_limits::min(); + const int32 max = std::numeric_limits::max(); + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR1({min, min, min, 0, 0, 0, max, max, max}); + auto rhs = builder.ConstantR1({min, 0, max, -1, 0, 1, min, 0, max}); + auto compare = builder.Gt(lhs, rhs); + + ComputeAndCompareR1( + &builder, {false, false, false, true, false, false, true, true, false}, + {}); +} + +TEST_F(ArrayElementwiseOpTest, CompareLeS32s) { + const int32 min = std::numeric_limits::min(); + const int32 max = std::numeric_limits::max(); + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR1({min, min, min, 0, 0, 0, max, max, max}); + auto rhs = builder.ConstantR1({min, 0, max, -1, 0, 1, min, 0, max}); + auto compare = builder.Le(lhs, rhs); + + ComputeAndCompareR1( + &builder, {true, true, true, false, true, true, false, false, true}, {}); +} + +TEST_F(ArrayElementwiseOpTest, CompareLtS32s) { + const int32 min = std::numeric_limits::min(); + const int32 max = std::numeric_limits::max(); + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR1({min, min, min, 0, 0, 0, max, max, max}); + auto rhs = 
builder.ConstantR1({min, 0, max, -1, 0, 1, min, 0, max}); + auto compare = builder.Lt(lhs, rhs); + + ComputeAndCompareR1( + &builder, {false, true, true, false, false, true, false, false, false}, + {}); +} + +TEST_F(ArrayElementwiseOpTest, CompareEqU32s) { + const uint32 max = std::numeric_limits::max(); + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR1({0, 0, 0, 5, 5, 5, max, max, max}); + auto rhs = builder.ConstantR1({0, 1, max, 4, 5, 6, 0, 1, max}); + auto compare = builder.Eq(lhs, rhs); + + ComputeAndCompareR1( + &builder, {true, false, false, false, true, false, false, false, true}, + {}); +} + +TEST_F(ArrayElementwiseOpTest, CompareNeU32s) { + const uint32 max = std::numeric_limits::max(); + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR1({0, 0, 0, 5, 5, 5, max, max, max}); + auto rhs = builder.ConstantR1({0, 1, max, 4, 5, 6, 0, 1, max}); + auto compare = builder.Ne(lhs, rhs); + + ComputeAndCompareR1( + &builder, {false, true, true, true, false, true, true, true, false}, {}); +} + +TEST_F(ArrayElementwiseOpTest, CompareGeU32s) { + const uint32 max = std::numeric_limits::max(); + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR1({0, 0, 0, 5, 5, 5, max, max, max}); + auto rhs = builder.ConstantR1({0, 1, max, 4, 5, 6, 0, 1, max}); + auto compare = builder.Ge(lhs, rhs); + + ComputeAndCompareR1( + &builder, {true, false, false, true, true, false, true, true, true}, {}); +} + +TEST_F(ArrayElementwiseOpTest, CompareGtU32s) { + const uint32 max = std::numeric_limits::max(); + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR1({0, 0, 0, 5, 5, 5, max, max, max}); + auto rhs = builder.ConstantR1({0, 1, max, 4, 5, 6, 0, 1, max}); + auto compare = builder.Gt(lhs, rhs); + + ComputeAndCompareR1( + &builder, {false, false, false, true, false, false, true, true, false}, + {}); +} + +TEST_F(ArrayElementwiseOpTest, CompareLeU32s) { + const uint32 max = std::numeric_limits::max(); + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR1({0, 0, 0, 5, 5, 5, max, max, max}); + auto rhs = builder.ConstantR1({0, 1, max, 4, 5, 6, 0, 1, max}); + auto compare = builder.Le(lhs, rhs); + + ComputeAndCompareR1( + &builder, {true, true, true, false, true, true, false, false, true}, {}); +} + +TEST_F(ArrayElementwiseOpTest, CompareLtU32s) { + const uint32 max = std::numeric_limits::max(); + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR1({0, 0, 0, 5, 5, 5, max, max, max}); + auto rhs = builder.ConstantR1({0, 1, max, 4, 5, 6, 0, 1, max}); + auto compare = builder.Lt(lhs, rhs); + + ComputeAndCompareR1( + &builder, {false, true, true, false, false, true, false, false, false}, + {}); +} + +TEST_F(ArrayElementwiseOpTest, PowF32s) { + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR1({4.0f, 2.0f, 2.0f, NAN, 6.0f}); + auto rhs = builder.ConstantR1({2.0f, -2.0f, 3.0f, 10.0f, NAN}); + auto minimum = builder.Pow(lhs, rhs); + + ComputeAndCompareR1(&builder, {16.0f, 0.25f, 8.0f, NAN, NAN}, {}, + error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, PowZeroElementF32s) { + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR1({}); + auto rhs = builder.ConstantR1({}); + auto minimum = builder.Pow(lhs, rhs); + + ComputeAndCompareR1(&builder, {}, {}, error_spec_); +} + +// Some Pow cases that can be implemented more efficiently. 
+TEST_F(ArrayElementwiseOpTest, PowSpecialF32) { + ComputationBuilder b(client_, TestName()); + + std::vector values = {1.0f, 2.0f, 3.2f, -4.0f}; + std::vector exponents = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; + + std::unique_ptr param_literal = LiteralUtil::CreateR1(values); + std::unique_ptr param_data = + client_->TransferToServer(*param_literal).ConsumeValueOrDie(); + + auto sum = b.ConstantR0(0.0f); + auto param = b.Parameter(0, param_literal->shape(), "param"); + for (float exponent : exponents) { + sum = b.Add(sum, b.Pow(param, b.ConstantR0(exponent))); + } + + std::vector expected; + for (auto value : values) { + float sum = 0.0f; + for (float exponent : exponents) { + sum += std::pow(value, exponent); + } + expected.push_back(sum); + } + + ComputeAndCompareR1(&b, expected, {param_data.get()}, error_spec_); +} + +TEST_P(ArrayElementwiseOpTestParamCount, SquareManyValues) { + const int count = GetParam(); + ComputationBuilder builder(client_, TestName()); + std::vector values; + for (int i = 0; i < count; ++i) { + values.push_back(i / static_cast(count)); + } + auto x = builder.ConstantR1(values); + auto exp = builder.Pow(x, builder.ConstantR0(2.0f)); + + std::vector expected; + for (float value : values) { + expected.push_back(value * value); + } + + ComputeAndCompareR1(&builder, expected, {}, error_spec_); +} + +TEST_F(ArrayElementwiseOpTest, SquareIn4D) { + ComputationBuilder builder(client_, TestName()); + Array4D values(2, 2, 2, 2); + + std::vector values_vector; + std::vector expected_vector; + for (int i = 0; i < values.num_elements(); ++i) { + values_vector.push_back(static_cast(i) / values.num_elements()); + expected_vector.push_back(values_vector.back() * values_vector.back()); + } + values.SetValues(values_vector); + + Array4D expected(2, 2, 2, 2, expected_vector); + + auto x = builder.ConstantR4FromArray4D(values); + auto exp = builder.Pow(x, builder.ConstantR0(2.0f)); + + ComputeAndCompareR4(&builder, expected, {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4DZeroElements) { + ComputationBuilder builder(client_, TestName()); + Array4D values(2, 2, 0, 2); + Array4D expected(2, 2, 0, 2); + + auto x = builder.ConstantR4FromArray4D(values); + auto exp = builder.Pow(x, builder.ConstantR0(2.0f)); + + ComputeAndCompareR4(&builder, expected, {}, error_spec_); +} + +// GPU backend emits nvvm intrinsic for fmin and fmax, whose semantics is NOT +// such +// * fmin(NaN, x) = x +// * fmax(NaN, x) = x +// so we only test NAN on CPU. +// +// TODO(b/28180546): Make this compile in a way that is consistent +// among backends. 
+TEST_F(ArrayElementwiseOpTest, MinF32s) { + ComputationBuilder builder(client_, TestName()); +#if !defined(XLA_TEST_BACKEND_CPU) + auto lhs = builder.ConstantR1({1.0f, 1.0f, 2.25f}); + auto rhs = builder.ConstantR1({2.0f, -5.0f, 1.0f}); +#else + auto lhs = builder.ConstantR1({1.0f, 1.0f, 2.25f, NAN, 6.0f}); + auto rhs = builder.ConstantR1({2.0f, -5.0f, 1.0f, 10.0f, NAN}); +#endif + auto minimum = builder.Min(lhs, rhs); + + ComputeAndCompareR1(&builder, +#if !defined(XLA_TEST_BACKEND_CPU) + {1.0f, -5.0f, 1.0f}, +#else + {1.0f, -5.0f, 1.0f, 10.0f, 6.0f}, +#endif + {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, MinZeroElementF32s) { + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR1({}); + auto rhs = builder.ConstantR1({}); + auto minimum = builder.Min(lhs, rhs); + ComputeAndCompareR1(&builder, {}, {}, error_spec_); +} + +// TODO(b/28180546): Make this compile in a way that is consistent +// among backends. See comment on MinF32s test above. +XLA_TEST_F(ArrayElementwiseOpTest, MinF64s) { + ComputationBuilder builder(client_, TestName()); +#if !defined(XLA_TEST_BACKEND_CPU) + auto lhs = builder.ConstantR1({1.0, 1.0, 2.25}); + auto rhs = builder.ConstantR1({2.0, -5.0, 1.0}); +#else + auto lhs = builder.ConstantR1({1.0, 1.0, 2.25, NAN, 6.0}); + auto rhs = builder.ConstantR1({2.0, -5.0, 1.0, 10.0, NAN}); +#endif + auto minimum = builder.Min(lhs, rhs); + + ComputeAndCompareR1(&builder, +#if !defined(XLA_TEST_BACKEND_CPU) + {1.0, -5.0, 1.0}, +#else + {1.0, -5.0, 1.0, 10.0, 6.0}, +#endif + {}, error_spec_); +} + +// TODO(b/28180546): Make this compile in a way that is consistent +// among backends. See comment on MinF32s test above. +TEST_F(ArrayElementwiseOpTest, MaxF32s) { + ComputationBuilder builder(client_, TestName()); +#if !defined(XLA_TEST_BACKEND_CPU) + auto lhs = builder.ConstantR1({1.0f, 1.0f, 2.25f}); + auto rhs = builder.ConstantR1({2.0f, -5.0f, 1.0f}); +#else + auto lhs = builder.ConstantR1({1.0f, 1.0f, 2.25f, NAN, 6.0f}); + auto rhs = builder.ConstantR1({2.0f, -5.0f, 1.0f, 10.0f, NAN}); +#endif + auto maximum = builder.Max(lhs, rhs); + + ComputeAndCompareR1(&builder, +#if !defined(XLA_TEST_BACKEND_CPU) + {2.0f, 1.0f, 2.25f}, +#else + {2.0f, 1.0f, 2.25f, 10.0f, 6.0f}, +#endif + {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, MaxZeroElementF32s) { + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR1({}); + auto rhs = builder.ConstantR1({}); + auto minimum = builder.Max(lhs, rhs); + ComputeAndCompareR1(&builder, {}, {}, error_spec_); +} + +// TODO(b/28180546): Make this compile in a way that is consistent +// among backends. See comment on MinF32s test above. 
+XLA_TEST_F(ArrayElementwiseOpTest, MaxF64s) { + ComputationBuilder builder(client_, TestName()); +#if !defined(XLA_TEST_BACKEND_CPU) + auto lhs = builder.ConstantR1({1.0, 1.0, 2.25}); + auto rhs = builder.ConstantR1({2.0, -5.0, 1.0}); +#else + auto lhs = builder.ConstantR1({1.0, 1.0, 2.25, NAN, 6.0}); + auto rhs = builder.ConstantR1({2.0, -5.0, 1.0, 10.0, NAN}); +#endif + auto maximum = builder.Max(lhs, rhs); + + ComputeAndCompareR1(&builder, +#if !defined(XLA_TEST_BACKEND_CPU) + {2.0, 1.0, 2.25}, +#else + {2.0, 1.0, 2.25, 10.0, 6.0}, +#endif + {}, error_spec_); +} + +TEST_F(ArrayElementwiseOpTest, MaxS32s) { + const int32 min = std::numeric_limits::min(); + const int32 max = std::numeric_limits::max(); + ComputationBuilder builder(client_, TestName()); + auto x = builder.ConstantR1( + {min, min, min, -1, -1, 0, 0, 0, 1, 1, max, max, max}); + auto y = builder.ConstantR1( + {min, max, 0, -10, 0, -1, 0, 1, 0, 10, 0, max, min}); + builder.Max(x, y); + + std::vector expected = {min, max, 0, -1, 0, 0, 0, + 1, 1, 10, max, max, max}; + ComputeAndCompareR1(&builder, expected, {}); +} + +TEST_F(ArrayElementwiseOpTest, MinS32s) { + const int32 min = std::numeric_limits::min(); + const int32 max = std::numeric_limits::max(); + ComputationBuilder builder(client_, TestName()); + auto x = builder.ConstantR1( + {min, min, min, -1, -1, 0, 0, 0, 1, 1, max, max, max}); + auto y = builder.ConstantR1( + {min, max, 0, -10, 0, -1, 0, 1, 0, 10, 0, max, min}); + builder.Min(x, y); + + std::vector expected = {min, min, min, -10, -1, -1, 0, + 0, 0, 1, 0, max, min}; + ComputeAndCompareR1(&builder, expected, {}); +} + +TEST_F(ArrayElementwiseOpTest, MaxU32s) { + const uint32 max = std::numeric_limits::max(); + ComputationBuilder builder(client_, TestName()); + auto x = builder.ConstantR1({0, 0, 1, 1, 1, max, max, max}); + auto y = builder.ConstantR1({0, 1, 0, 1, 10, 0, 234234, max}); + builder.Max(x, y); + + std::vector expected = {0, 1, 1, 1, 10, max, max, max}; + ComputeAndCompareR1(&builder, expected, {}); +} + +TEST_F(ArrayElementwiseOpTest, MinU32s) { + const uint32 max = std::numeric_limits::max(); + ComputationBuilder builder(client_, TestName()); + auto x = builder.ConstantR1({0, 0, 1, 1, 1, max, max, max}); + auto y = builder.ConstantR1({0, 1, 0, 1, 10, 0, 234234, max}); + builder.Min(x, y); + + std::vector expected = {0, 0, 0, 1, 1, 0, 234234, max}; + ComputeAndCompareR1(&builder, expected, {}); +} + +TEST_F(ArrayElementwiseOpTest, MaxTenF32s) { + ComputationBuilder builder(client_, TestName()); + auto x = builder.ConstantR1( + {-0.0, 1.0, 2.0, -3.0, -4.0, 5.0, 6.0, -7.0, -8.0, 9.0}); + auto y = builder.ConstantR1( + {-0.0, -1.0, -2.0, 3.0, 4.0, -5.0, -6.0, 7.0, 8.0, -9.0}); + builder.Max(x, y); + + std::vector expected = {-0.0, 1.0, 2.0, 3.0, 4.0, + 5.0, 6.0, 7.0, 8.0, 9.0}; + ComputeAndCompareR1(&builder, expected, {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S1AndR1S0F32s) { + ComputationBuilder builder(client_, TestName()); + auto u = builder.ConstantR1({3.5}); + auto v = builder.ConstantR1({}); + builder.Max(u, v); + + ComputeAndCompareR1(&builder, {}, {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S0AndR2S0x2F32s) { + for (int broadcast_dim : {0, 1}) { + ComputationBuilder builder(client_, TestName()); + auto u = builder.ConstantR1({3.5}); + auto v = builder.ConstantR2FromArray2D(Array2D(0, 2)); + builder.Max(u, v, /*broadcast_dimensions=*/{broadcast_dim}); + + ComputeAndCompareR2(&builder, Array2D(0, 2), {}, error_spec_); + } +} + +TEST_F(ArrayElementwiseOpTest, 
Max1DAnd2DF32s) { + ComputationBuilder builder(client_, TestName()); + auto v = builder.ConstantR1({2.0f, 3.0f, 4.0f}); + auto m = + builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); + builder.Max(v, m, /*broadcast_dimensions=*/{1}); + + Array2D expected({{2.0f, 3.14f, 4.0f}, {2.25f, 3.0f, 4.0f}}); + ComputeAndCompareR2(&builder, expected, {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DZeroElementF32s) { + ComputationBuilder builder(client_, TestName()); + auto v = builder.ConstantR1({}); + auto m = builder.ConstantR2({{}, {}}); + builder.Max(v, m, /*broadcast_dimensions=*/{1}); + + Array2D expected({{}, {}}); + ComputeAndCompareR2(&builder, expected, {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarS32s) { + ComputationBuilder builder(client_, TestName()); + auto scalar = builder.ConstantR0(2); + Array3D a_3d({{{3, 9, -1}, {2, -10, 3}}, {{-2, 2, 8}, {12, 10, 4}}}); + auto array = builder.ConstantR3FromArray3D(a_3d); + builder.Max(array, scalar, /*broadcast_dimensions=*/{}); + + Array3D expected({{{3, 9, 2}, {2, 2, 3}}, {{2, 2, 8}, {12, 10, 4}}}); + ComputeAndCompareR3(&builder, expected, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarZeroElementS32s) { + ComputationBuilder builder(client_, TestName()); + auto scalar = builder.ConstantR0(2); + Array3D a_3d(2, 0, 3); + auto array = builder.ConstantR3FromArray3D(a_3d); + builder.Max(array, scalar, /*broadcast_dimensions=*/{}); + + Array3D expected(2, 0, 3); + ComputeAndCompareR3(&builder, expected, {}); +} + +TEST_F(ArrayElementwiseOpTest, Min2DTo1DF32s) { + ComputationBuilder builder(client_, TestName()); + auto m = + builder.ConstantR2({{-10.4f, 64.0f, 6.0f}, {0.1f, 32.0f, 16.1f}}); + auto v = builder.ConstantR1({-10.2f, 16.4f}); + builder.Min(m, v, /*broadcast_dimensions=*/{0}); + + Array2D expected({{-10.4f, -10.2f, -10.2f}, {0.1f, 16.4f, 16.1f}}); + ComputeAndCompareR2(&builder, expected, {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DZeroElementF32s) { + ComputationBuilder builder(client_, TestName()); + auto m = builder.ConstantR2({{}, {}}); + auto v = builder.ConstantR1({-10.2f, 16.4f}); + builder.Min(m, v, /*broadcast_dimensions=*/{0}); + + Array2D expected({{}, {}}); + ComputeAndCompareR2(&builder, expected, {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DF32s) { + ComputationBuilder builder(client_, TestName()); + auto array2d = + builder.ConstantR2({{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}}); + auto array4d = builder.ConstantR4FromArray4D( + {{{{-12.1f, 32.3f, 6.2f}}, {{0.0f, 32.5f, 3.0f}}}, + {{{-2.5f, 64.29f, 6.5f}}, {{-0.01f, 32.25f, 2.6f}}}}); + builder.Min(array2d, array4d, /*broadcast_dimensions=*/{1, 3}); + + Array4D expected( + {{{{-12.2f, 32.3f, 6.1f}}, {{0.0f, 32.2f, 2.5f}}}, + {{{-12.2f, 64.29f, 6.1f}}, {{-0.01f, 32.2f, 2.5f}}}}); + ComputeAndCompareR4(&builder, expected, {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DZeroElementF32s) { + ComputationBuilder builder(client_, TestName()); + auto array2d = + builder.ConstantR2({{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}}); + Array4D arg(2, 2, 0, 3); + auto array4d = builder.ConstantR4FromArray4D(arg); + builder.Min(array2d, array4d, /*broadcast_dimensions=*/{1, 3}); + + Array4D expected(2, 2, 0, 3); + ComputeAndCompareR4(&builder, expected, {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, MinTenS32s) { + ComputationBuilder builder(client_, TestName()); + auto x = builder.ConstantR1({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto y 
= builder.ConstantR1({9, 8, 7, 6, 5, 4, 3, 2, 1, 0}); + builder.Min(x, y); + + std::vector expected = {0, 1, 2, 3, 4, 4, 3, 2, 1, 0}; + ComputeAndCompareR1(&builder, expected, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, MaxTenS32s) { + ComputationBuilder builder(client_, TestName()); + auto x = builder.ConstantR1({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto y = builder.ConstantR1({9, 8, 7, 6, 5, 4, 3, 2, 1, 0}); + builder.Max(x, y); + + std::vector expected = {9, 8, 7, 6, 5, 5, 6, 7, 8, 9}; + ComputeAndCompareR1(&builder, expected, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, RemTwoConstantS32s) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({-3, 26, 2, -1, 1}); + auto b = builder.ConstantR1({10, 5, 1, 10, -10}); + auto add = builder.Rem(a, b); + + ComputeAndCompareR1(&builder, {-3, 1, 0, -1, 1}, {}); +} + +TEST_F(ArrayElementwiseOpTest, NonNanClampF32) { + ComputationBuilder builder(client_, TestName()); + auto minimum = builder.ConstantR1({1.0f, -6.5f, 1.0f, 2.25f, 0.0f}); + auto argument = builder.ConstantR1({2.0f, 10.0f, -5.0f, 1.0f, 10.0f}); + auto maximum = builder.ConstantR1({3.0f, 0.5f, 25.5f, 5.0f, 123.0}); + auto clamp = builder.Clamp(minimum, argument, maximum); + + ComputeAndCompareR1(&builder, {2.0f, 0.5f, 1.0f, 2.25f, 10.0f}, {}, + error_spec_); +} + +TEST_F(ArrayElementwiseOpTest, ClampF32Scalar) { + ComputationBuilder builder(client_, TestName()); + auto minimum = builder.ConstantR0(0.0f); + auto argument = builder.ConstantR1({2.0f, 10.0f, -5.0f, 1.0f, 4.0f}); + auto maximum = builder.ConstantR0(5.0f); + auto clamp = builder.Clamp(minimum, argument, maximum); + + ComputeAndCompareR1(&builder, {2.0f, 5.0f, 0.0f, 1.0f, 4.0f}, {}, + error_spec_); +} + +TEST_F(ArrayElementwiseOpTest, ClampF32ScalarVector) { + ComputationBuilder builder(client_, TestName()); + auto min_scalar = builder.ConstantR0(0.0f); + auto min_vector = builder.ConstantR1({1.0f, -6.5f, 1.0f, 2.25f, 0.0f}); + auto arg_vector = builder.ConstantR1({2.0f, 10.0f, -5.0f, 1.0f, 4.0f}); + auto arg_scalar = builder.ConstantR1({2.0f, 10.0f, -5.0f, 1.0f, 4.0f}); + auto max_scalar = builder.ConstantR0(3.0f); + auto max_vector = builder.ConstantR1({3.0f, 0.5f, 25.5f, 5.0f, 123.0}); + // Perform clamp with broadcasted scalar and vector. 
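+ // Each of the four clamps below mixes scalar and vector bounds; their sum
+ // is the expected result, e.g. for element 0 every clamp yields 2.0f, so
+ // the first expected value is 8.0f.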
+ auto clamp = builder.Add( + builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar), + builder.Clamp(min_scalar, arg_vector, max_vector)), + builder.Add(builder.Clamp(min_vector, arg_scalar, max_vector), + builder.Clamp(min_scalar, arg_scalar, max_vector))); + + ComputeAndCompareR1(&builder, {8.0f, 4.5f, 2.0f, 6.5f, 15.0f}, {}, + error_spec_); +} + +TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) { + ComputationBuilder builder(client_, TestName()); + + std::unique_ptr param0_literal = + LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); + std::unique_ptr param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + + std::unique_ptr param1_literal = + LiteralUtil::CreateR1({7.2f, 2.3f, 3.4f, 5.6f}); + std::unique_ptr param1_data = + client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + + auto p0 = builder.Parameter(0, param0_literal->shape(), "param0"); + auto p1 = builder.Parameter(1, param1_literal->shape(), "param1"); + auto add = builder.Add(p0, p1); + + ComputeAndCompareR1(&builder, {8.3f, 4.5f, 6.7f, 11.1f}, + {param0_data.get(), param1_data.get()}, + error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) { + ComputationBuilder builder(client_, TestName()); + + std::unique_ptr param0_literal = + LiteralUtil::CreateR3FromArray3D(Array3D(0, 7, 0)); + std::unique_ptr param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + + std::unique_ptr param1_literal = + LiteralUtil::CreateR3FromArray3D(Array3D(0, 7, 0)); + std::unique_ptr param1_data = + client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + + auto p0 = builder.Parameter(0, param0_literal->shape(), "param0"); + auto p1 = builder.Parameter(1, param1_literal->shape(), "param1"); + auto add = builder.Add(p0, p1); + + Array3D expected(0, 7, 0); + ComputeAndCompareR3( + &builder, expected, {param0_data.get(), param1_data.get()}, error_spec_); +} + +TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) { + ComputationBuilder builder(client_, TestName()); + + std::unique_ptr param0_literal = + LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); + std::unique_ptr param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + + auto a = builder.ConstantR1({1.1f, 2.2f, 3.3f, 4.4f}); + auto p = builder.Parameter(0, param0_literal->shape(), "param0"); + auto add = builder.Add(a, p); + + ComputeAndCompareR1(&builder, {2.2f, 4.4f, 6.6f, 9.9f}, + {param0_data.get()}, error_spec_); +} + +TEST_F(ArrayElementwiseOpTest, TanhF32s) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({-2.5f, 3.14f, 2.25f}); + auto result = builder.Tanh(a); + + ComputeAndCompareR1(&builder, {-0.986614f, 0.996260f, 0.978026}, {}, + error_spec_); +} + +TEST_F(ArrayElementwiseOpTest, AddChainFoldLeft) { + // a ------ (add) --------- (add) + // / / + // b -----/ / + // c---------------------/ + ComputationBuilder builder(client_, TestName()); + + auto a = builder.ConstantR1({1.1f, 2.2f, 3.3f, 4.4f}); + auto b = builder.ConstantR1({2.1f, 3.2f, 4.3f, 5.4f}); + auto c = builder.ConstantR1({-3.3f, -15.5f, -7.7f, -29.9f}); + + auto add = builder.Add(a, b); + auto add2 = builder.Add(add, c); + + ComputeAndCompareR1(&builder, {-0.1f, -10.1f, -0.1f, -20.1f}, {}, + error_spec_); +} + +TEST_F(ArrayElementwiseOpTest, AddChainFoldRight) { + // b ------ (add) --------- (add) + // / / + // c -----/ / + // a---------------------/ + ComputationBuilder builder(client_, TestName()); + + auto a = builder.ConstantR1({91.1f, 2.2f, 
3.3f, 4.4f}); + auto b = builder.ConstantR1({2.1f, 3.2f, 4.3f, 5.4f}); + auto c = builder.ConstantR1({-3.3f, -15.5f, -7.7f, -29.9f}); + + auto add = builder.Add(b, c); + auto add2 = builder.Add(a, add); + + ComputeAndCompareR1(&builder, {89.9f, -10.1f, -0.1f, -20.1f}, {}, + error_spec_); +} + +TEST_F(ArrayElementwiseOpTest, AddWithNeg) { + // a ----- (neg) ----- (add) + // / + // b ----- (neg) ----/ + ComputationBuilder builder(client_, TestName()); + + auto a = builder.ConstantR1({91.1f, 2.2f, 3.3f, 4.4f}); + auto b = builder.ConstantR1({2.1f, 3.2f, 4.3f, 5.4f}); + + auto neg_a = builder.Neg(a); + auto neg_b = builder.Neg(b); + auto result = builder.Add(neg_a, neg_b); + + ComputeAndCompareR1(&builder, {-93.2f, -5.4f, -7.6f, -9.8f}, {}, + error_spec_); +} + +TEST_F(ArrayElementwiseOpTest, AddChainTwoSide) { + // a ------ (add) ------------\ + // / \ + // b -----/ (add) + // / + // c ------ (add) ------------/ + // / + // d -----/ + ComputationBuilder builder(client_, TestName()); + + auto a = builder.ConstantR1({91.1f, 2.2f, 3.3f, 4.4f}); + auto b = builder.ConstantR1({2.1f, 3.2f, 4.3f, 5.4f}); + auto c = builder.ConstantR1({-3.3f, -15.5f, -7.7f, -29.9f}); + auto d = builder.ConstantR1({-19.0f, 10.0f, -40.0f, 20.2f}); + + auto add_ab = builder.Add(a, b); + auto add_cd = builder.Add(c, d); + auto add_all = builder.Add(add_ab, add_cd); + + ComputeAndCompareR1(&builder, {70.9f, -0.1f, -40.1f, 0.1f}, {}, + error_spec_); +} + +TEST_F(ArrayElementwiseOpTest, 2DBinaryOpF32s) { + ComputationBuilder builder(client_, TestName()); + auto a = + builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); + auto b = + builder.ConstantR2({{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}}); + auto add = builder.Add(a, b); + + Array2D expected_array( + {{-4.0f, 11.28f, 43.0f}, {1.25f, -14.0f, 8.88f}}); + ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, ScalarPlus2DF32) { + // Add a scalar + matrix. + ComputationBuilder builder(client_, TestName()); + auto a = + builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); + auto scalar = builder.ConstantR0(3.0f); + auto add = builder.Add(scalar, a); + + Array2D expected_array({{0.5f, 6.14f, 4.0f}, {5.25f, -7.0f, 6.33f}}); + ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); +} + +TEST_F(ArrayElementwiseOpTest, 2DPlusScalarF32) { + // Add a matrix + scalar. + ComputationBuilder builder(client_, TestName()); + auto a = + builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); + auto scalar = builder.ConstantR0(3.0f); + auto add = builder.Add(a, scalar); + + Array2D expected_array({{0.5f, 6.14f, 4.0f}, {5.25f, -7.0f, 6.33f}}); + ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32) { + // Test simple broadcasting of a R1F32 over R2F32. The vector's size matches + // only dim 0 of the matrix. + ComputationBuilder builder(client_, TestName()); + auto v = builder.ConstantR1({20.0f, 40.0f, 60.0f}); + // clang-format off + auto m = builder.ConstantR2({ + {-2.5f, 3.14f, 1.0f}, + {2.25f, -10.0f, 3.33f}}); + // clang-format on + auto add = builder.Add(v, m, /*broadcast_dimensions=*/{1}); + Array2D expected_array( + {{17.5f, 43.14f, 61.0f}, {22.25f, 30.0f, 63.33f}}); + ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Eq) { + // Test broadcasting in Eq comparison. 
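+ // broadcast_dimensions={1} compares v against each row of m, while
+ // broadcast_dimensions={0} compares v against each column; the two literals
+ // in the expected tuple differ accordingly.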
+ ComputationBuilder builder(client_, TestName()); + auto v = builder.ConstantR1({42, 73}); + auto m = builder.ConstantR2({{42, 73}, {42, 52}}); + + // This test exercises both possible broadcast dimensions for a vector/matrix + // comparison. + auto cmp_dim_0 = builder.Eq(v, m, /*broadcast_dimensions=*/{1}); + auto cmp_dim_1 = builder.Eq(v, m, /*broadcast_dimensions=*/{0}); + auto result = builder.Tuple({cmp_dim_0, cmp_dim_1}); + + auto expected = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR2({{true, true}, {true, false}}).get(), + LiteralUtil::CreateR2({{true, false}, {false, false}}).get()}); + ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ne) { + // Test broadcasting in Ne comparison. + ComputationBuilder builder(client_, TestName()); + auto v = builder.ConstantR1({42, 73}); + auto m = builder.ConstantR2({{42, 73}, {42, 52}}); + auto cmp = builder.Ne(v, m, /*broadcast_dimensions=*/{1}); + + const string expected = R"(pred[2,2] { + { 00 }, + { 01 }, +})"; + EXPECT_EQ(expected, ExecuteToString(&builder, {})); +} + +XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ge) { + // Test broadcasting in Ge comparison. + ComputationBuilder builder(client_, TestName()); + auto v = builder.ConstantR1({1, 2, 3, 4}); + auto m = builder.ConstantR2({{1, 0, 5, 6}, {42, 52, 10, 4}}); + auto cmp = builder.Ge(v, m, /*broadcast_dimensions=*/{1}); + + const string expected = R"(pred[2,4] { + { 1100 }, + { 0001 }, +})"; + EXPECT_EQ(expected, ExecuteToString(&builder, {})); +} + +XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Gt) { + // Test broadcasting in Gt comparison. + ComputationBuilder builder(client_, TestName()); + auto v = builder.ConstantR1({1, 2, 3, 4}); + auto m = builder.ConstantR2({{1, 0, 5, 6}, {42, 52, 10, 4}}); + auto cmp = builder.Gt(v, m, /*broadcast_dimensions=*/{1}); + + const string expected = R"(pred[2,4] { + { 0100 }, + { 0000 }, +})"; + EXPECT_EQ(expected, ExecuteToString(&builder, {})); +} + +XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Le) { + // Test broadcasting in Le comparison. + ComputationBuilder builder(client_, TestName()); + auto v = builder.ConstantR1({1, 2, 3, 4}); + auto m = builder.ConstantR2({{1, 0, 5, 6}, {42, 52, 10, 4}}); + auto cmp = builder.Le(v, m, /*broadcast_dimensions=*/{1}); + + const string expected = R"(pred[2,4] { + { 1011 }, + { 1111 }, +})"; + EXPECT_EQ(expected, ExecuteToString(&builder, {})); +} + +XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Lt) { + // Test broadcasting in Lt comparison. + ComputationBuilder builder(client_, TestName()); + auto v = builder.ConstantR1({1, 2, 3, 4}); + auto m = builder.ConstantR2({{1, 0, 5, 6}, {42, 52, 10, 4}}); + auto cmp = builder.Lt(v, m, /*broadcast_dimensions=*/{1}); + + const string expected = R"(pred[2,4] { + { 0011 }, + { 1110 }, +})"; + EXPECT_EQ(expected, ExecuteToString(&builder, {})); +} + +TEST_F(ArrayElementwiseOpTest, Mul2Dby1DF32) { + // Test simple broadcasting of a R1F32 over R2F32 when the order of binary op + // arguments is reversed. 
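+ // With broadcast_dimensions={1}, column j of m is scaled by v[j]; row 0 of
+ // the expected result is {1.5f * 2, 2.5f * 4, 3.5f * 6} = {3, 10, 21}.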
+ ComputationBuilder builder(client_, TestName()); + auto m = builder.ConstantR2({{1.5f, 2.5f, 3.5f}, {4.5f, 5.5f, 6.5f}}); + auto v = builder.ConstantR1({2.0f, 4.0f, 6.0f}); + auto add = builder.Mul(m, v, /*broadcast_dimensions=*/{1}); + Array2D expected_array({{3.0f, 10.0f, 21.0f}, {9.0f, 22.0f, 39.0f}}); + ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); +} + +TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim1) { + // Tests broadcasting for arrays with degenerate (size == 1) dimensions. + ComputationBuilder builder(client_, TestName()); + // m's shape in XLA notation is {3, 2} + // md's shape in XLA notation is {3, 1} + // The result has shape {3, 2}, where md is broadcast over m + auto m = + builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); + auto md = builder.ConstantR2({{10.0f, 20.0f, 30.0f}}); + auto add = builder.Add(m, md); + Array2D expected_array( + {{7.5f, 23.14f, 31.0f}, {12.25f, 10.0f, 33.33f}}); + ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim0) { + // Tests broadcasting for arrays with degenerate (size == 1) dimensions. + ComputationBuilder builder(client_, TestName()); + // m's shape in XLA notation is {3, 2} + // md's shape in XLA notation is {1, 2} + // The result has shape {3, 2}, where md is broadcast over m + auto m = + builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); + auto md = builder.ConstantR2({{10.0f}, {20.0f}}); + auto add = builder.Add(m, md); + Array2D expected_array( + {{7.5f, 13.14f, 11.0f}, {22.25f, 10.0f, 23.33f}}); + ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, Add2DsWithDegenerateDimsOuterProduct) { + // Tests broadcasting for two degenerate arrays. This kind of broadcasting + // effectively creates an "outer product" operation. + // This is taken from the Numpy docs example at: + // http://docs.scipy.org/doc/numpy-1.10.1/user/basics.broadcasting.html + ComputationBuilder builder(client_, TestName()); + // a's shape in XLA notation is {1, 4} + // b's shape in XLA notation is {3, 1} + // The result has shape {3, 4}. + auto a = builder.ConstantR2({{0.0f}, {10.0f}, {20.0f}, {30.0f}}); + auto b = builder.ConstantR2({{1.0f, 2.0f, 3.0f}}); + auto add = builder.Add(a, b); + Array2D expected_array({{1.0f, 2.0f, 3.0f}, + {11.0f, 12.0f, 13.0f}, + {21.0f, 22.0f, 23.0f}, + {31.0f, 32.0f, 33.0f}}); + ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver1) { + // Add together a (2,2) array and a (2) array, using dimension 0 for + // broadcasting (though there are two ways to broadcast these shapes). + ComputationBuilder builder(client_, TestName()); + auto v = builder.ConstantR1({20.0f, 40.0f}); + auto m = builder.ConstantR2({{10.0f, 50.0f}, {77.0f, 88.0f}}); + auto add = builder.Add(v, m, /*broadcast_dimensions=*/{1}); + Array2D expected_array({{30.0f, 90.0f}, {97.0f, 128.0f}}); + ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); +} + +TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver0) { + // Add together a (2,2) array and a (2) array, using dimension 1 for + // broadcasting (though there are two ways to broadcast these shapes). 
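+ // broadcast_dimensions={0} maps v onto dimension 0 of m, so v[i] is added
+ // to every element of row i: {{10+20, 50+20}, {77+40, 88+40}}.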
+ ComputationBuilder builder(client_, TestName()); + auto v = builder.ConstantR1({20.0f, 40.0f}); + auto m = builder.ConstantR2({{10.0f, 50.0f}, {77.0f, 88.0f}}); + auto add = builder.Add(v, m, /*broadcast_dimensions=*/{0}); + Array2D expected_array({{30.0f, 70.0f}, {117.0f, 128.0f}}); + ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); +} + +TEST_F(ArrayElementwiseOpTest, 3DBinaryOpF32s) { + // Binary add of two R3s together + ComputationBuilder builder(client_, TestName()); + Array3D a_3d({{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}}, + {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}}); + auto a = builder.ConstantR3FromArray3D(a_3d); + + Array3D b_3d({{{2.0f, 4.0f}, {6.0f, 8.0f}, {10.0f, 12.0f}}, + {{14.0f, 16.0f}, {18.0f, 20.0f}, {22.0f, 24.0f}}}); + auto b = builder.ConstantR3FromArray3D(b_3d); + auto add = builder.Add(a, b); + + Array3D expected_3d( + {{{3.0f, 6.0f}, {9.0f, 12.0f}, {15.0f, 18.0f}}, + {{21.0f, 24.0f}, {27.0f, 30.0f}, {33.0f, 36.0f}}}); + ComputeAndCompareR3(&builder, expected_3d, {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver2) { + // Add together a (2, 3, 2) array with a (2) array, using dimension 0 for + // broadcasting (though there are two ways to broadcast these shapes). + ComputationBuilder builder(client_, TestName()); + // clang-format off + Array3D a_3d({ + {{1.0f, 2.0f}, + {3.0f, 4.0f}, + {5.0f, 6.0f}}, + {{7.0f, 8.0f}, + {9.0f, 10.0f}, + {11.0f, 12.0f}}, + }); + // clang-format on + auto a = builder.ConstantR3FromArray3D(a_3d); + auto v = builder.ConstantR1({10.0f, 20.0f}); + auto add = builder.Add(a, v, /*broadcast_dimensions=*/{2}); + + Array3D expected_3d( + {{{11.0f, 22.0f}, {13.0f, 24.0f}, {15.0f, 26.0f}}, + {{17.0f, 28.0f}, {19.0f, 30.0f}, {21.0f, 32.0f}}}); + ComputeAndCompareR3(&builder, expected_3d, {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver0) { + // Add together a (2, 3, 2) array with a (2) array, using dimension 2 for + // broadcasting (though there are two ways to broadcast these shapes). + ComputationBuilder builder(client_, TestName()); + // clang-format off + Array3D a_3d({ + {{1.0f, 2.0f}, + {3.0f, 4.0f}, + {5.0f, 6.0f}}, + {{7.0f, 8.0f}, + {9.0f, 10.0f}, + {11.0f, 12.0f}}, + }); + // clang-format on + auto a = builder.ConstantR3FromArray3D(a_3d); + auto v = builder.ConstantR1({10.0f, 20.0f}); + auto add = builder.Add(a, v, /*broadcast_dimensions=*/{0}); + + // clang-format off + Array3D expected_3d({ + {{11.0f, 12.0f}, + {13.0f, 14.0f}, + {15.0f, 16.0f}}, + {{27.0f, 28.0f}, + {29.0f, 30.0f}, + {31.0f, 32.0f}}, + }); + // clang-format on + ComputeAndCompareR3(&builder, expected_3d, {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo3D) { + // Add together a (2, 3, 2) array with a (3, 2) array, using dimensions {1,2} + // for broadcasting. 
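+ // In the code below m is a (2, 3) constant and broadcast_dimensions={0, 1}
+ // maps its dimensions onto dimensions 0 and 1 of a_3d, so m(p, z) is added
+ // to both elements of a_3d(p, z, :).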
+ ComputationBuilder builder(client_, TestName()); + // clang-format off + Array3D a_3d({ + {{1.0f, 2.0f}, + {3.0f, 4.0f}, + {5.0f, 6.0f}}, + {{7.0f, 8.0f}, + {9.0f, 10.0f}, + {11.0f, 12.0f}}, + }); + auto a = builder.ConstantR3FromArray3D(a_3d); + auto m = builder.ConstantR2({ + {10.0f, 20.0f, 30.0f}, + {40.0f, 50.0f, 60.0f}, + }); + auto add = builder.Add(a, m, /*broadcast_dimensions=*/{0, 1}); + + Array3D expected_3d({ + {{11.0f, 12.0f}, + {23.0f, 24.0f}, + {35.0f, 36.0f}}, + {{47.0f, 48.0f}, + {59.0f, 60.0f}, + {71.0f, 72.0f}}, + }); + // clang-format on + ComputeAndCompareR3(&builder, expected_3d, {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, CompareGtR3F32sWithDegenerateDim2) { + // Comparison between two 3D arrays of compatible shapes: + // (2, 3, 2) and (2, 3, 1): expected to produce a (2, 3, 2) shape of PREDs. + ComputationBuilder builder(client_, TestName()); + Array3D a_3d({{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}}, + {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}}); + auto a = builder.ConstantR3FromArray3D(a_3d); + + Array3D b_3d({{{7.0f, 1.0f}, {3.0f, 10.0f}, {15.0f, 6.0f}}}); + auto b = builder.ConstantR3FromArray3D(b_3d); + + auto compare = builder.Gt(a, b); + + Array3D expected_3d( + {{{0, 1}, {0, 0}, {0, 0}}, {{0, 1}, {1, 0}, {0, 1}}}); + const string expected = R"(pred[2,3,2] { +{ { 01 }, + { 00 }, + { 00 } }, +{ { 01 }, + { 10 }, + { 01 } } +})"; + EXPECT_EQ(expected, ExecuteToString(&builder, {})); +} + +TEST_F(ArrayElementwiseOpTest, 4DBinaryOpF32s) { + ComputationBuilder builder(client_, TestName()); + + std::unique_ptr> operand_a_4d(new Array4D(2, 3, 4, 5)); + std::unique_ptr> operand_b_4d(new Array4D(2, 3, 4, 5)); + std::unique_ptr> expected_4d(new Array4D(2, 3, 4, 5)); + float value = 0.0; + for (int64 p = 0; p < 2; ++p) { + for (int64 z = 0; z < 3; ++z) { + for (int64 y = 0; y < 4; ++y) { + for (int64 x = 0; x < 5; ++x) { + (*operand_a_4d)(p, z, y, x) = value; + (*operand_b_4d)(p, z, y, x) = 2.0 * value; + (*expected_4d)(p, z, y, x) = 3.0 * value; + value += 0.1; + } + } + } + } + + auto a = builder.ConstantR4FromArray4D(*operand_a_4d); + auto b = builder.ConstantR4FromArray4D(*operand_b_4d); + auto add = builder.Add(a, b); + + ComputeAndCompareR4(&builder, *expected_4d, {}, error_spec_); +} + +TEST_F(ArrayElementwiseOpTest, R4PlusR1InDim1) { + ComputationBuilder builder(client_, TestName()); + + std::unique_ptr> operand_a_4d(new Array4D(2, 3, 4, 5)); + std::unique_ptr> expected_4d(new Array4D(2, 3, 4, 5)); + std::vector operand_b_1d(3); + std::iota(operand_b_1d.begin(), operand_b_1d.end(), 1.0); + + float value = 0.0; + for (int64 p = 0; p < 2; ++p) { + for (int64 z = 0; z < 3; ++z) { + for (int64 y = 0; y < 4; ++y) { + for (int64 x = 0; x < 5; ++x) { + (*operand_a_4d)(p, z, y, x) = value; + (*expected_4d)(p, z, y, x) = value + operand_b_1d[z]; + value += 0.1; + } + } + } + } + + auto a = builder.ConstantR4FromArray4D(*operand_a_4d); + auto b = builder.ConstantR1(operand_b_1d); + auto add = builder.Add(a, b, {1}); + + ComputeAndCompareR4(&builder, *expected_4d, {}, error_spec_); +} + +TEST_F(ArrayElementwiseOpTest, R4_32x64x2x2_Plus_R1_64) { + constexpr int d0 = 16; + constexpr int d1 = 16; + constexpr int d2 = 2; + constexpr int d3 = 2; + Array4D r4(d0, d1, d2, d3); + r4.Fill(1.0); + std::vector r1(d1); + std::iota(r1.begin(), r1.end(), 1.0); + + ComputationBuilder builder(client_, TestName()); + std::unique_ptr a_literal = LiteralUtil::CreateR4FromArray4D(r4); + *a_literal->mutable_shape()->mutable_layout() = + LayoutUtil::MakeLayout({0, 1, 2, 3}); 
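+ // The explicit {0, 1, 2, 3} minor-to-major layout reverses the default
+ // rank-4 layout, presumably to exercise the non-default-layout path for
+ // constant literals.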
+ auto a = builder.ConstantLiteral(*a_literal); + auto b = builder.ConstantR1(r1); + builder.Add(a, b, {1}); + + for (int i0 = 0; i0 < d0; ++i0) { + for (int i1 = 0; i1 < d1; ++i1) { + for (int i2 = 0; i2 < d2; ++i2) { + for (int i3 = 0; i3 < d3; ++i3) { + r4(i0, i1, i2, i3) += r1[i1]; + } + } + } + } + ComputeAndCompareR4(&builder, r4, {}, error_spec_); +} + +// Show that we can't add two opaques. +TEST_F(ArrayElementwiseOpTest, CannotAddOpaques) { + ComputationBuilder builder(client_, TestName()); + auto shape = ShapeUtil::MakeOpaqueShape(); + auto x = builder.Parameter(0, shape, "x"); + auto concatenated = builder.Add(x, x); + StatusOr computation_status = builder.Build(); + ASSERT_FALSE(computation_status.ok()); + EXPECT_MATCH(computation_status.status().ToString(), + testing::ContainsRegex( + "Expected non-opaque argument for lhs of binary operation")); +} + +// Regression test for b/31927799. "slice - y" is fused and requires implicit +// broadcast. +TEST_F(ArrayElementwiseOpTest, ImplictBroadcastInFusedExpressions) { + ComputationBuilder builder(client_, TestName()); + auto x_literal = LiteralUtil::CreateR1({1, 2, 3}); + auto y_literal = LiteralUtil::CreateR1({4, 5}); + auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); + auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); + + auto x = builder.Parameter(0, x_literal->shape(), "x"); + auto y = builder.Parameter(1, y_literal->shape(), "y"); + auto slice = builder.Slice(x, {1}, {2}); + builder.Sub(slice, y); + + ComputeAndCompareR1(&builder, {-2, -3}, {x_data.get(), y_data.get()}, + error_spec_); +} + +INSTANTIATE_TEST_CASE_P(ArrayElementwiseOpTestParamCount, + ArrayElementwiseOpTestParamCount, + ::testing::Values(127, 128, 129, 17 * 4096)); + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::legacy_flags::AppendLlvmBackendFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/axpy_simple_test.cc b/tensorflow/compiler/xla/tests/axpy_simple_test.cc new file mode 100644 index 0000000000..adffac09e3 --- /dev/null +++ b/tensorflow/compiler/xla/tests/axpy_simple_test.cc @@ -0,0 +1,90 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +class AxpySimpleTest : public ClientLibraryTestBase {}; + +TEST_F(AxpySimpleTest, AxTenValues) { + ComputationBuilder builder(client_, "ax_10"); + auto alpha = builder.ConstantR0(3.1415926535); + auto x = builder.ConstantR1( + {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); + auto ax = builder.Mul(alpha, x); + + std::vector expected = { + -3.14159265, 3.14159265, 6.28318531, -6.28318531, -9.42477796, + 9.42477796, 12.56637061, -12.56637061, -15.70796327, 15.70796327}; + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(AxpySimpleTest, AxpyZeroValues) { + ComputationBuilder builder(client_, "axpy_10"); + auto alpha = builder.ConstantR0(3.1415926535); + auto x = builder.ConstantR1({}); + auto y = builder.ConstantR1({}); + auto ax = builder.Mul(alpha, x); + auto axpy = builder.Add(ax, y); + + std::vector expected = {}; + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); +} + +TEST_F(AxpySimpleTest, AxpyTenValues) { + ComputationBuilder builder(client_, "axpy_10"); + auto alpha = builder.ConstantR0(3.1415926535); + auto x = builder.ConstantR1( + {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); + auto y = builder.ConstantR1( + {5.0, -5.0, -4.0, 4.0, 3.0, -3.0, -2.0, 2.0, 1.0, -1.0}); + auto ax = builder.Mul(alpha, x); + auto axpy = builder.Add(ax, y); + + std::vector expected = { + 1.85840735, -1.85840735, 2.28318531, -2.28318531, -6.42477796, + 6.42477796, 10.56637061, -10.56637061, -14.70796327, 14.70796327}; + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc b/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc new file mode 100644 index 0000000000..c7b533b80f --- /dev/null +++ b/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc @@ -0,0 +1,85 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// Tests that passing a bad shape to RNG's output parameter causes a validation +// failure rather than causing a crash. + +#include + +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +class BadRngShapeValidationTest : public ClientLibraryTestBase {}; + +TEST_F(BadRngShapeValidationTest, DefaultConstructedShapeCreatesError) { + ComputationBuilder builder(client_, TestName()); + auto zero = builder.ConstantR0(0.0); + auto one = builder.ConstantR0(1.0); + Shape default_constructed; + builder.RngUniform(zero, one, default_constructed); + + StatusOr computation = builder.Build(); + EXPECT_FALSE(computation.ok()); + LOG(INFO) << "status received: " << computation.status(); + EXPECT_MATCH(computation.status().error_message(), + testing::HasSubstr("shape has invalid")); +} + +TEST_F(BadRngShapeValidationTest, ShapeWithoutLayoutIsOk) { + ComputationBuilder builder(client_, TestName()); + auto zero = builder.ConstantR0(0.0); + auto one = builder.ConstantR0(1.0); + Shape sans_layout; + sans_layout.set_element_type(F32); + sans_layout.add_dimensions(1); + + builder.RngUniform(zero, one, sans_layout); + + StatusOr computation = builder.Build(); + ASSERT_TRUE(computation.ok()); + LOG(INFO) << computation.status(); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc new file mode 100644 index 0000000000..598fd69909 --- /dev/null +++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc @@ -0,0 +1,210 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class BatchNormalizationTest : public ClientLibraryTestBase { + protected: + BatchNormalizationTest() : input_array_(kSamples, kZ, kY, kX) { + Array2D pz({ + // z0 z1 + {-1.0f, 4.1f}, // p0 + {2.0f, 4.1f}, // p1 + {5.0f, 4.4f}, // p2 + }); + input_array_.FillWithPZ(pz); + input_literal_ = *LiteralUtil::CreateR4FromArray4D(input_array_); + CHECK_EQ(kSamples, input_array_.planes()); + CHECK_EQ(kZ, input_array_.depth()); + CHECK_EQ(kY, input_array_.height()); + CHECK_EQ(kY, input_array_.width()); + } + + static constexpr int64 kSamples = 3; + static constexpr int64 kX = 1; + static constexpr int64 kY = 1; + static constexpr int64 kZ = 2; + + Array4D input_array_; + Literal input_literal_; + const ErrorSpec error_spec_{0.001, 0.001}; +}; + +TEST_F(BatchNormalizationTest, SubtractInZ) { + ComputationBuilder builder(client_, "subtract_in_z_one_sample"); + auto x = builder.ConstantLiteral(input_literal_); + auto y = builder.ConstantR1({3.14, 4.25}); + builder.Sub(x, y, /*broadcast_dimensions=*/{1}); + + Array4D expected(kSamples, kZ, kY, kX); + Array2D pz({ + {-1.0f - 3.14f, 4.1f - 4.25f}, // p0 + {2.0f - 3.14f, 4.1f - 4.25f}, // p1 + {5.0f - 3.14f, 4.4f - 4.25f}, // p2 + }); + expected.FillWithPZ(pz); + ComputeAndCompareR4(&builder, expected, {}, error_spec_); +} + +TEST_F(BatchNormalizationTest, SquareTesseractElementwise) { + ComputationBuilder builder(client_, "square_tesseract_elementwise"); + auto x = builder.ConstantLiteral(input_literal_); + builder.SquareF32(x); + + Array4D expected(kSamples, kZ, kY, kX); + Array2D expected_pz({ + {std::pow(-1.0f, 2.0f), std::pow(4.1f, 2.0f)}, + {std::pow(2.0f, 2.0f), std::pow(4.1f, 2.0f)}, + {std::pow(5.0f, 2.0f), std::pow(4.4f, 2.0f)}, + }); + expected.FillWithPZ(expected_pz); + ComputeAndCompareR4(&builder, expected, {}, error_spec_); +} + +TEST_F(BatchNormalizationTest, SumToZ) { + ComputationBuilder builder(client_, "sum_to_z"); + auto input_activations = builder.ConstantLiteral(input_literal_); + Computation add = CreateScalarAddComputation(F32, &builder); + // Reduce all but the Z dimension. 
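+ // With the fixture input the per-Z sums are -1.0f + 2.0f + 5.0f = 6 and
+ // 4.1f + 4.1f + 4.4f = 12.6, matching the expected vector below.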
+ builder.Reduce(input_activations, builder.ConstantR0(0.0f), add, + {0, 2, 3}); + + std::vector expected = {6, 12.6}; + ComputeAndCompareR1(&builder, expected, {}, error_spec_); +} + +TEST_F(BatchNormalizationTest, SquareAndReduce) { + ComputationBuilder builder(client_, "square_and_reduce"); + auto input_activations = builder.ConstantLiteral(input_literal_); + auto set_means = builder.ConstantR1({2.f, 4.2f}); + auto activation_deviations = builder.Sub(input_activations, set_means, + /*broadcast_dimensions=*/{1}); + Computation add = CreateScalarAddComputation(F32, &builder); + auto dev_squares = builder.SquareF32(activation_deviations); + auto sum_of_squares = builder.Reduce( + dev_squares, builder.ConstantR0(0.0f), add, {0, 2, 3}); + + std::vector expected = {18, 0.06}; + ComputeAndCompareR1(&builder, expected, {}, error_spec_); +} + +TEST_F(BatchNormalizationTest, VarianceToStddev) { + ComputationBuilder builder(client_, "variance_to_stddev"); + auto variance = builder.ConstantR1({6.f, .02f}); + auto sqrt = builder.SqrtF32(variance); + + std::vector expected = {2.44948974f, 0.14142136f}; + ComputeAndCompareR1(&builder, expected, {}, error_spec_); +} + +// Compare against a forward batch normalization example in the NN spec +// reference. +TEST_F(BatchNormalizationTest, SpecComparisonForward) { + ComputationBuilder builder(client_, "batch_normalize_per_spec"); + auto input_activations = + builder.CheckShape(builder.ConstantLiteral(input_literal_), + ShapeUtil::MakeShape(F32, {3, 2, 1, 1})); + auto gamma = builder.ConstantR1({1.0, 1.0}); + auto beta = builder.ConstantR1({0.0, 0.0}); + Computation add = CreateScalarAddComputation(F32, &builder); + // Reduce all dimensions except dimension 1. + Shape TwoElementVectorF32 = ShapeUtil::MakeShape(F32, {2}); + auto sum = builder.CheckShape( + builder.Reduce(input_activations, builder.ConstantR0(0.0f), add, + /*dimensions_to_reduce=*/{0, 2, 3}), + TwoElementVectorF32); + auto input_shape = builder.GetShape(input_activations).ConsumeValueOrDie(); + auto sum_shape = builder.GetShape(sum).ConsumeValueOrDie(); + auto count = builder.ConstantR0(ShapeUtil::ElementsIn(*input_shape) / + ShapeUtil::ElementsIn(*sum_shape)); + auto set_means = builder.Div(sum, count); + + const float kEpsilon = 1e-9f; + auto epsilon = builder.ConstantR0(kEpsilon); + auto epsilon2 = builder.ConstantR1({kEpsilon, kEpsilon}); + auto activation_deviations = builder.Sub(input_activations, set_means, + /*broadcast_dimensions=*/{1}); + auto dev_squares = builder.SquareF32(activation_deviations); + auto sum_of_squares = builder.CheckShape( + builder.Reduce(dev_squares, builder.ConstantR0(0.0f), add, + /*dimensions_to_reduce=*/{0, 2, 3}), + TwoElementVectorF32); + auto variance = builder.Div(sum_of_squares, count); + auto standard_deviation = builder.SqrtF32(variance); + auto standard_deviation_above_epsilon = builder.CheckShape( + builder.Gt(standard_deviation, epsilon), ShapeUtil::MakeShape(PRED, {2})); + auto gt_eps = builder.Select(standard_deviation_above_epsilon, + standard_deviation, epsilon2); + auto normalization_factors = builder.ReciprocalF32(gt_eps); + auto normalized_input_activations = + builder.Mul(activation_deviations, normalization_factors, + /*broadcast_dimensions=*/{1}); + /* auto output_activations = */ builder.Add( + builder.Mul(normalized_input_activations, gamma, + /*broadcast_dimensions=*/{1}), + beta, /*broadcast_dimensions=*/{1}); + + Array4D expected(kSamples, kZ, kY, kX); + Array2D pz({ + {-3.f / std::sqrt(6.f), -.1f / std::sqrt(.02f)}, + {0.f, -.1f / 
std::sqrt(.02f)}, + {3.f / std::sqrt(6.f), .2f / std::sqrt(.02f)}, + }); + expected.FillWithPZ(pz); + + ComputeAndCompareR4(&builder, expected, {}, error_spec_); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/binop_scaling_test.cc b/tensorflow/compiler/xla/tests/binop_scaling_test.cc new file mode 100644 index 0000000000..e825bd435b --- /dev/null +++ b/tensorflow/compiler/xla/tests/binop_scaling_test.cc @@ -0,0 +1,157 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/reference_util.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +class BinopScalingTest : public ClientLibraryTestBase {}; + +TEST_F(BinopScalingTest, MatrixPlusPseudoMatrixRowVector_32x4) { + auto alhs = MakeLinspaceArray2D(0.0, 1.0, 32, 4); + auto arhs = MakeLinspaceArray2D(0.0, 1.0, 1, 4); + + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR2FromArray2D(*alhs); + auto rhs = builder.ConstantR2FromArray2D(*arhs); + builder.Add(lhs, rhs); + + auto aexpected = ReferenceUtil::MapWithIndexArray2D( + *alhs, [&](float lhs_value, int64 row, int64 col) { + return lhs_value + (*arhs)(0, col); + }); + ComputeAndCompareR2(&builder, *aexpected, {}, ErrorSpec(0.0001)); +} + +TEST_F(BinopScalingTest, MatrixPlusPseudoMatrixRowVector_129x129) { + auto alhs = MakeLinspaceArray2D(0.0, 1.0, 129, 129); + auto arhs = MakeLinspaceArray2D(0.0, 1.0, 1, 129); + + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR2FromArray2D(*alhs); + auto rhs = builder.ConstantR2FromArray2D(*arhs); + builder.Add(lhs, rhs); + + auto aexpected = ReferenceUtil::MapWithIndexArray2D( + *alhs, [&](float lhs_value, int64 row, int64 col) { + return lhs_value + (*arhs)(0, col); + }); + ComputeAndCompareR2(&builder, *aexpected, {}, ErrorSpec(0.0001)); +} + +TEST_F(BinopScalingTest, MatrixPlusPseudoMatrixColVector_9x5) { + auto alhs = 
MakeLinspaceArray2D(0.0, 1.0, 9, 5); + auto arhs = MakeLinspaceArray2D(0.0, 1.0, 9, 1); + + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR2FromArray2D(*alhs); + auto rhs = builder.ConstantR2FromArray2D(*arhs); + builder.Add(lhs, rhs); + + auto aexpected = ReferenceUtil::MapWithIndexArray2D( + *alhs, [&](float lhs_value, int64 row, int64 col) { + return lhs_value + (*arhs)(row, 0); + }); + ComputeAndCompareR2(&builder, *aexpected, {}, ErrorSpec(0.0001)); +} + +TEST_F(BinopScalingTest, MatrixPlusPseudoMatrixColVector_129x257) { + auto alhs = MakeLinspaceArray2D(0.0, 1.0, 129, 257); + auto arhs = MakeLinspaceArray2D(0.0, 1.0, 129, 1); + + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR2FromArray2D(*alhs); + auto rhs = builder.ConstantR2FromArray2D(*arhs); + builder.Add(lhs, rhs); + + auto aexpected = ReferenceUtil::MapWithIndexArray2D( + *alhs, [&](float lhs_value, int64 row, int64 col) { + return lhs_value + (*arhs)(row, 0); + }); + ComputeAndCompareR2(&builder, *aexpected, {}, ErrorSpec(0.0001)); +} + +TEST_F(BinopScalingTest, R0PlusR2F32) { + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR0(42.0); + auto rhs = builder.ConstantR2({ + {1.0, 2.0}, {3.0, 4.0}, + }); + builder.Add(lhs, rhs); + + Array2D expected(2, 2); + expected(0, 0) = 42.0 + 1.0; + expected(0, 1) = 42.0 + 2.0; + expected(1, 0) = 42.0 + 3.0; + expected(1, 1) = 42.0 + 4.0; + ComputeAndCompareR2(&builder, expected, {}, ErrorSpec(0.0001)); +} + +TEST_F(BinopScalingTest, R4PlusR0S32) { + ComputationBuilder builder(client_, TestName()); + // clang-format off + Array4D lhs_array({ + {{{1, 2}, + {3, 4}, + {5, 6}}}, + {{{7, 8}, + {9, 10}, + {11, 12}}}, + }); + Array4D expected({ + {{{43, 44}, + {45, 46}, + {47, 48}}}, + {{{49, 50}, + {51, 52}, + {53, 54}}}, + }); + // clang-format on + + auto lhs = builder.ConstantR4FromArray4D(lhs_array); + auto rhs = builder.ConstantR0(42); + builder.Add(lhs, rhs); + ComputeAndCompareR4(&builder, expected, {}); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc new file mode 100644 index 0000000000..200d4d4563 --- /dev/null +++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc @@ -0,0 +1,179 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +using BroadcastSimpleTest = ClientLibraryTestBase; + +XLA_TEST_F(BroadcastSimpleTest, ScalarNoOpBroadcast) { + ComputationBuilder b(client_, TestName()); + b.Broadcast(b.ConstantR0(1.5), {}); + ComputeAndCompareR0(&b, 1.5, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_2x3) { + ComputationBuilder b(client_, TestName()); + b.Broadcast(b.ConstantR0(2.25), {2, 3}); + Array2D expected(2, 3, 2.25); + ComputeAndCompareR2(&b, expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_2x0) { + ComputationBuilder b(client_, TestName()); + b.Broadcast(b.ConstantR0(2.25), {2, 0}); + Array2D expected(2, 0); + ComputeAndCompareR2(&b, expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_0x2) { + ComputationBuilder b(client_, TestName()); + b.Broadcast(b.ConstantR0(2.25), {0, 2}); + Array2D expected(0, 2); + ComputeAndCompareR2(&b, expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(BroadcastSimpleTest, 1DTo2D) { + ComputationBuilder b(client_, TestName()); + b.Broadcast(b.ConstantR1({1, 2, 3}), {2}); + + Array2D expected(2, 3); + expected(0, 0) = 1; + expected(0, 1) = 2; + expected(0, 2) = 3; + expected(1, 0) = 1; + expected(1, 1) = 2; + expected(1, 2) = 3; + ComputeAndCompareR2(&b, expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(BroadcastSimpleTest, ZeroElement_1DTo2D) { + ComputationBuilder b(client_, TestName()); + b.Broadcast(b.ConstantR1({}), {2}); + + Array2D expected(2, 0); + ComputeAndCompareR2(&b, expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(BroadcastSimpleTest, 1DToZeroElement2D) { + ComputationBuilder b(client_, TestName()); + b.Broadcast(b.ConstantR1({1, 2, 3}), {0}); + + Array2D expected(0, 3); + ComputeAndCompareR2(&b, expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) { + // Verify that binary op and degenerate dimension broadcast work together in + // the same operation. + // + // The lhs shape [1, 2] is first broadcast up to [2, 1, 2] using in-dimension + // broadcasting (broadcast_dimensions {1, 2}), then is added to the rhs shape + // [2, 3, 1]. Degenerate dimension broadcasting then broadcasts the size one + // dimensions. 
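+ // The result therefore has shape [2, 3, 2]; e.g. output(0, 0, :) =
+ // {1.0 + 2.0, 5.0 + 2.0} = {3.0, 7.0}, the first row of the expected
+ // literal.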
+ ComputationBuilder b(client_, TestName()); + + b.Add(b.ConstantR2({{1.0, 5.0}}), + b.ConstantLiteral(*LiteralUtil::CreateR3( + {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})), + /*broadcast_dimensions=*/{1, 2}); + + auto expected = + LiteralUtil::CreateR3({{{3.0, 7.0}, {4.0, 8.0}, {5.0, 9.0}}, + {{6.0, 10.0}, {7.0, 11.0}, {8.0, 12.0}}}); + + ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) { + // Binary dimension broadcasting of the smaller lhs ([2, 2] up to [2, 2, 2]) + // results in a shape incompatible with the lhs [2, 3, 1]. + ComputationBuilder b(client_, TestName()); + + b.Add(b.ConstantR2({{1.0, 5.0}, {1.0, 5.0}}), + b.ConstantLiteral(*LiteralUtil::CreateR3( + {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})), + /*broadcast_dimensions=*/{1, 2}); + + auto result_status = Execute(&b, {}); + EXPECT_FALSE(result_status.ok()); + EXPECT_MATCH(result_status.status().error_message(), + testing::ContainsRegex("broadcast dimension 0 mismatch")); +} + +XLA_TEST_F(BroadcastSimpleTest, InvalidInDimensionBroadcasting) { + // Test invalid broadcasting with [1, 2] and [2, 3] inputs. + ComputationBuilder b(client_, TestName()); + + b.Add(b.ConstantR2({{1.0, 2.0}}), + b.ConstantR2({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}})); + + auto result_status = Execute(&b, {}); + EXPECT_FALSE(result_status.ok()); + EXPECT_MATCH( + result_status.status().error_message(), + testing::ContainsRegex("binary op BINOP_ADD with incompatible shapes")); +} + +XLA_TEST_F(BroadcastSimpleTest, InvalidDegenerateBroadcasting) { + // Test invalid broadcasting with [1, 2] and [2, 3] inputs. + ComputationBuilder b(client_, TestName()); + + b.Add(b.ConstantR2({{1.0, 2.0}}), + b.ConstantR2({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}})); + + auto result_status = Execute(&b, {}); + EXPECT_FALSE(result_status.ok()); + EXPECT_MATCH( + result_status.status().error_message(), + testing::ContainsRegex("binary op BINOP_ADD with incompatible shapes")); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/broadcast_test.cc b/tensorflow/compiler/xla/tests/broadcast_test.cc new file mode 100644 index 0000000000..1796a732e5 --- /dev/null +++ b/tensorflow/compiler/xla/tests/broadcast_test.cc @@ -0,0 +1,286 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +class BroadcastTest : public HloTestBase {}; + +XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) { + // Test degenerate case of broadcasting a scalar into a scalar. + auto builder = HloComputation::Builder(TestName()); + auto input = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + builder.AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShape(F32, {}), input, {})); + + // Create HLO module, compile, and execute. + auto hlo_module = MakeUnique(TestName()); + hlo_module->AddEntryComputation(builder.Build()); + auto result = ExecuteAndTransfer(std::move(hlo_module), {}); + + LiteralTestUtil::ExpectNear(*LiteralUtil::CreateR0(42.0), *result, + error_spec_); +} + +XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) { + auto builder = HloComputation::Builder(TestName()); + auto input = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + builder.AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShape(F32, {2, 2}), input, {})); + + // Create HLO module, compile, and execute. + auto hlo_module = MakeUnique(TestName()); + hlo_module->AddEntryComputation(builder.Build()); + auto result = ExecuteAndTransfer(std::move(hlo_module), {}); + + LiteralTestUtil::ExpectNear( + *LiteralUtil::CreateR2({{42.0, 42.0}, {42.0, 42.0}}), *result, + error_spec_); +} + +XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) { + auto builder = HloComputation::Builder(TestName()); + auto input = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1.0, 2.0, 3.0}))); + + // Broadcast vector in both dimension 0 and dimension 1. Join them in a tuple + // to enable testing of the results. + auto element1 = builder.AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShape(F32, {3, 2}), input, {0})); + auto element2 = builder.AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShape(F32, {2, 3}), input, {1})); + builder.AddInstruction(HloInstruction::CreateTuple({element1, element2})); + + // Create HLO module, compile, and execute. 
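+  // (Added note: ExecuteAndTransfer compiles the module on the test backend,
+  // runs it, and transfers the result back as a Literal. Because the root of
+  // the computation is a tuple, both broadcasts are checked from one
+  // execution: tuple_literals(0) holds the [3, 2] broadcast along dimension 0
+  // and tuple_literals(1) holds the [2, 3] broadcast along dimension 1.)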
+ auto hlo_module = MakeUnique(TestName()); + hlo_module->AddEntryComputation(builder.Build()); + auto result = ExecuteAndTransfer(std::move(hlo_module), {}); + + LiteralTestUtil::ExpectNear( + *LiteralUtil::CreateR2({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}), + result->tuple_literals(0), error_spec_); + + LiteralTestUtil::ExpectNear( + *LiteralUtil::CreateR2({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}), + result->tuple_literals(1), error_spec_); +} + +XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) { + auto builder = HloComputation::Builder(TestName()); + auto input = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + builder.AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShape(F32, {2, 2}), input, {0, 1})); + + // Create HLO module, compile, and execute. + auto hlo_module = MakeUnique(TestName()); + hlo_module->AddEntryComputation(builder.Build()); + auto result = ExecuteAndTransfer(std::move(hlo_module), {}); + + LiteralTestUtil::ExpectNear( + *LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), *result, + error_spec_); +} + +XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) { + // Degenerately broadcasting a shape into a shape of the same rank reorders + // the dimensions, ie transpose. + auto builder = HloComputation::Builder(TestName()); + auto input = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + builder.AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShape(F32, {2, 2}), input, {1, 0})); + + // Create HLO module, compile, and execute. + auto hlo_module = MakeUnique(TestName()); + hlo_module->AddEntryComputation(builder.Build()); + auto result = ExecuteAndTransfer(std::move(hlo_module), {}); + + LiteralTestUtil::ExpectNear( + *LiteralUtil::CreateR2({{1.0, 3.0}, {2.0, 4.0}}), *result, + error_spec_); +} + +XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) { + auto builder = HloComputation::Builder(TestName()); + auto input = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + builder.AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShape(F32, {2, 3, 2}), input, {0, 2})); + + // Create HLO module, compile, and execute. + auto hlo_module = MakeUnique(TestName()); + hlo_module->AddEntryComputation(builder.Build()); + auto result = ExecuteAndTransfer(std::move(hlo_module), {}); + + LiteralTestUtil::ExpectNear( + *LiteralUtil::CreateR3({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, + {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}), + *result, error_spec_); +} + +TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) { + auto builder = HloComputation::Builder(TestName()); + auto input = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1({1.0, 2.0}))); + + // Broadcast vector in dimension 1. + builder.AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShape(F32, {2, 2, 3, 3}), input, {1})); + + // Create HLO module, compile, and execute. 
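+  // (Added note: with broadcast dimension {1}, element (p, z, y, x) of the
+  // [2, 2, 3, 3] result equals input[z], i.e. every (y, x) plane is constant.
+  // The expected Array4D below encodes the same thing via FillWithPZ on the
+  // 2x2 plane {{1, 2}, {1, 2}}.)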
+ auto hlo_module = MakeUnique(TestName()); + hlo_module->AddEntryComputation(builder.Build()); + auto result = ExecuteAndTransfer(std::move(hlo_module), {}); + + Array4D expected(2, 2, 3, 3); + Array2D pz({{1, 2}, {1, 2}}); + expected.FillWithPZ(pz); + + LiteralTestUtil::ExpectNear( + *LiteralUtil::CreateR4FromArray4D(expected), *result, error_spec_); +} + +TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) { + auto builder = HloComputation::Builder(TestName()); + std::vector input_data(1025); + int64 r1_size = input_data.size(); + std::iota(input_data.begin(), input_data.end(), 0.0f); + auto input = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1(input_data))); + + // Broadcast vector in dimension 3. + builder.AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShape(F32, {3, 3, 3, r1_size}), input, {3})); + + // Create HLO module, compile, and execute. + auto hlo_module = MakeUnique(TestName()); + hlo_module->AddEntryComputation(builder.Build()); + auto result = ExecuteAndTransfer(std::move(hlo_module), {}); + + Array4D expected(3, 3, 3, 1025); + Array2D yx(/*height=*/3, /*width=*/r1_size); + for (int64 y = 0; y < 3; ++y) { + for (int64 x = 0; x < r1_size; ++x) { + yx(y, x) = input_data[x]; + } + } + expected.FillWithYX(yx); + + LiteralTestUtil::ExpectNear( + *LiteralUtil::CreateR4FromArray4D(expected), *result, error_spec_); +} + +XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) { + auto builder = HloComputation::Builder(TestName()); + Array4D r4_array(32, 64, 7, 7); + r4_array.Fill(42.0); + std::vector r1_array(64, 42.0); + + auto input = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1(r1_array))); + + // Broadcast vector in dimension 1. + builder.AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShape(F32, {32, 64, 7, 7}), input, {1})); + + // Create HLO module, compile, and execute. + auto hlo_module = MakeUnique(TestName()); + hlo_module->AddEntryComputation(builder.Build()); + auto result = ExecuteAndTransfer(std::move(hlo_module), {}); + + LiteralTestUtil::ExpectNear(*LiteralUtil::CreateR4FromArray4D(r4_array), + *result, error_spec_); +} + +TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) { + auto builder = HloComputation::Builder(TestName()); + auto input = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); + builder.AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShape(F32, {64, 64, 3, 3}), input, {})); + + // Create HLO module, compile, and execute. + auto hlo_module = MakeUnique(TestName()); + hlo_module->AddEntryComputation(builder.Build()); + LOG(INFO) << hlo_module->ToString(); + auto result = ExecuteAndTransfer(std::move(hlo_module), {}); + + Array4D expected(64, 64, 3, 3); + expected.Fill(1.0f); + + LiteralTestUtil::ExpectNear( + *LiteralUtil::CreateR4FromArray4D(expected), *result, error_spec_); +} + +TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) { + auto builder = HloComputation::Builder(TestName()); + Array2D to_broadcast({{1.0f, 2.0f}, {3.0f, 4.0f}}); + auto input = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2FromArray2D(to_broadcast))); + + // Broadcast vector in dimensions 2 and 3. + builder.AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShape(F32, {3, 3, 2, 2}), input, {2, 3})); + + // Create HLO module, compile, and execute. 
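+  // (Added note: broadcasting the 2x2 operand into dimensions {2, 3} of the
+  // [3, 3, 2, 2] shape replicates the matrix across the two leading
+  // dimensions, so element (a, b, y, x) equals to_broadcast(y, x); the
+  // expected Array4D below encodes this via FillWithYX.)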
+ auto hlo_module = MakeUnique(TestName()); + hlo_module->AddEntryComputation(builder.Build()); + auto result = ExecuteAndTransfer(std::move(hlo_module), {}); + + Array4D expected(3, 3, 2, 2); + expected.FillWithYX(to_broadcast); + + LiteralTestUtil::ExpectNear( + *LiteralUtil::CreateR4FromArray4D(expected), *result, error_spec_); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/build_defs.bzl b/tensorflow/compiler/xla/tests/build_defs.bzl new file mode 100644 index 0000000000..2c7eeb820d --- /dev/null +++ b/tensorflow/compiler/xla/tests/build_defs.bzl @@ -0,0 +1,149 @@ +"""Build rules for XLA testing.""" + +load("@local_config_cuda//cuda:build_defs.bzl", "cuda_is_configured") + +def all_backends(): + if cuda_is_configured(): + return ["cpu", "cpu_parallel", "gpu"] + else: + return ["cpu", "cpu_parallel"] + +def xla_test(name, + srcs, + deps, + backends=[], + args=[], + tags=[], + copts=[], + backend_tags={}, + backend_args={}, + **kwargs): + """Generates cc_test targets for the given XLA backends. + + This rule generates a cc_test target for one or more XLA backends and also + a platform-agnostic cc_library rule. The arguments are identical to cc_test + with two additions: 'backends' and 'backend_args'. 'backends' specifies the + backends to generate tests for ("cpu", "cpu_parallel", "gpu"), and + 'backend_args'/'backend_tags' specifies backend-specific args parameters to + use when generating the cc_test. + + The name of the cc_tests are the provided name argument with the backend name + appended, and the cc_library target name is the provided name argument with + "_lib" appended. For example, if name parameter is "foo_test", then the cpu + test target will be "foo_test_cpu" and the cc_library target is "foo_lib". + + The cc_library target can be used to link with other plugins outside of + xla_test. + + The build rule also defines a test suite ${name} which includes the tests for + each of the supported backends. + + Each generated cc_test target has a tag indicating which backend the test is + for. This tag is of the form "xla_${BACKEND}" (eg, "xla_cpu"). These + tags can be used to gather tests for a particular backend into a test_suite. + + Examples: + + # Generates the targets: foo_test_cpu and foo_test_gpu. + xla_test( + name = "foo_test", + srcs = ["foo_test.cc"], + backends = ["cpu", "gpu"], + deps = [...], + ) + + # Generates the targets: bar_test_cpu and bar_test_gpu. bar_test_cpu + # includes the additional arg "--special_cpu_flag". + xla_test( + name = "bar_test", + srcs = ["bar_test.cc"], + backends = ["cpu", "gpu"], + backend_args = {"cpu": ["--special_cpu_flag"]} + deps = [...], + ) + + The build rule defines the preprocessor macro XLA_TEST_BACKEND_${BACKEND} + to the value 1 where ${BACKEND} is the uppercase name of the backend. + + Args: + name: Name of the target. + srcs: Sources for the target. + deps: Dependencies of the target. + backends: A list of backends to generate tests for. Supported + values: "cpu", "cpu_parallel", "gpu". 
If this list is empty, the test will + be generated for all supported backends. + args: Test arguments for the target. + tags: Tags for the target. + backend_args: A dict mapping backend name to list of additional args to + use for that target. + backend_tags: A dict mapping backend name to list of additional tags to + use for that target. + """ + test_names = [] + if not backends: + backends = all_backends() + + native.cc_library( + name="%s_lib" % name, + srcs=srcs, + copts=copts, + testonly=True, + deps=deps + ["//tensorflow/compiler/xla/tests:test_macros_header"], + ) + + for backend in backends: + test_name = "%s_%s" % (name, backend) + this_backend_tags = ["xla_%s" % backend] + this_backend_copts = [] + this_backend_args = backend_args.get(backend, []) + if backend == "cpu": + backend_deps = ["//tensorflow/compiler/xla/service:cpu_plugin"] + backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_cpu"] + elif backend == "cpu_parallel": + backend_deps = ["//tensorflow/compiler/xla/service:cpu_plugin"] + backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_cpu"] + this_backend_args += ["--xla_cpu_parallel=true"] + elif backend == "gpu": + backend_deps = ["//tensorflow/compiler/xla/service:gpu_plugin"] + backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_gpu"] + this_backend_tags += ["requires-gpu-sm35"] + else: + fail("Unknown backend %s" % backend) + + native.cc_test( + name=test_name, + srcs=srcs, + tags=tags + backend_tags.get(backend, []) + this_backend_tags, + copts=copts + ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()] + + this_backend_copts, + args=args + this_backend_args, + deps=deps + backend_deps, + **kwargs) + + test_names.append(test_name) + + native.test_suite(name=name, tests=test_names) + + +def generate_backend_suites(backends=[]): + if not backends: + backends = all_backends() + for backend in backends: + native.test_suite(name="%s_tests" % backend, + tags = ["xla_%s" % backend]) + + +def generate_backend_test_macros(backends=[]): + if not backends: + backends = all_backends() + for backend in backends: + native.cc_library( + name="test_macros_%s" % backend, + testonly = True, + hdrs = ["test_macros.h"], + copts = ["-DXLA_PLATFORM=\\\"%s\\\"" % backend.upper()], + deps = [ + "//tensorflow/compiler/xla:types", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ]) diff --git a/tensorflow/compiler/xla/tests/call_test.cc b/tensorflow/compiler/xla/tests/call_test.cc new file mode 100644 index 0000000000..1c96b73034 --- /dev/null +++ b/tensorflow/compiler/xla/tests/call_test.cc @@ -0,0 +1,115 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +class CallOpTest : public ClientLibraryTestBase { + protected: + Computation CreateR0F32IdentityComputation() { + ComputationBuilder builder(client_, "Identity"); + builder.Parameter(0, r0f32_, "x"); + auto build_status = builder.Build(); + EXPECT_IS_OK(build_status.status()); + return build_status.ConsumeValueOrDie(); + } + + Computation CreateR1S0F32AdditionComputation() { + ComputationBuilder builder(client_, "Addition"); + auto x = builder.Parameter(0, r1s0f32_, "x"); + auto y = builder.Parameter(1, r1s0f32_, "y"); + builder.Add(x, y); + auto build_status = builder.Build(); + EXPECT_IS_OK(build_status.status()); + return build_status.ConsumeValueOrDie(); + } + + Computation CreateR1S2F32AdditionComputation() { + ComputationBuilder builder(client_, "Addition"); + auto x = builder.Parameter(0, r1s2f32_, "x"); + auto y = builder.Parameter(1, r1s2f32_, "y"); + builder.Add(x, y); + auto build_status = builder.Build(); + EXPECT_IS_OK(build_status.status()); + return build_status.ConsumeValueOrDie(); + } + + Shape r0f32_ = ShapeUtil::MakeShape(F32, {}); + Shape r1s0f32_ = ShapeUtil::MakeShape(F32, {0}); + Shape r1s2f32_ = ShapeUtil::MakeShape(F32, {2}); +}; + +XLA_TEST_F(CallOpTest, DISABLED_ON_GPU(CallR0F32IdentityScalar)) { + ComputationBuilder builder(client_, TestName()); + Computation callee = CreateR0F32IdentityComputation(); + auto constant = builder.ConstantLiteral(*LiteralUtil::CreateR0(42.0)); + builder.Call(callee, {constant}); + + ComputeAndCompareR0(&builder, 42.0, {}, ErrorSpec(0.01f)); +} + +XLA_TEST_F(CallOpTest, DISABLED_ON_GPU(CallR1S0F32AddArray)) { + ComputationBuilder builder(client_, TestName()); + Computation callee = CreateR1S0F32AdditionComputation(); + auto x = builder.ConstantLiteral(*LiteralUtil::CreateR1({})); + auto y = builder.ConstantLiteral(*LiteralUtil::CreateR1({})); + builder.Call(callee, {x, y}); + + ComputeAndCompareR1(&builder, {}, {}, ErrorSpec(0.01f)); +} + +XLA_TEST_F(CallOpTest, DISABLED_ON_GPU(CallR1S2F32AddArray)) { + ComputationBuilder builder(client_, TestName()); + Computation callee = CreateR1S2F32AdditionComputation(); + auto x = builder.ConstantLiteral(*LiteralUtil::CreateR1({1.0f, 2.0f})); + auto y = builder.ConstantLiteral(*LiteralUtil::CreateR1({2.0f, 3.0f})); + builder.Call(callee, {x, y}); + + ComputeAndCompareR1(&builder, {3.0f, 5.0f}, {}, ErrorSpec(0.01f)); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << 
"\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc new file mode 100644 index 0000000000..675c9fccb0 --- /dev/null +++ b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc @@ -0,0 +1,138 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +class CheckExecutionArityTest : public ClientLibraryTestBase {}; + +TEST_F(CheckExecutionArityTest, TwoParamComputationNumArguments) { + ComputationBuilder builder(client_, "add_two_params"); + auto param_literal = LiteralUtil::CreateR1({1.1f, 2.2f}); + + auto p0 = builder.Parameter(0, param_literal->shape(), "param0"); + auto p1 = builder.Parameter(1, param_literal->shape(), "param1"); + auto add = builder.Add(p0, p1); + + auto param0_data = + client_->TransferToServer(*param_literal).ConsumeValueOrDie(); + auto param1_data = + client_->TransferToServer(*param_literal).ConsumeValueOrDie(); + + auto computation_status = builder.Build(); + ASSERT_IS_OK(computation_status.status()); + auto computation = computation_status.ConsumeValueOrDie(); + + // The arity of the UserComputation is 2 arguments. Execution will succeed + // with 2 arguments, but fail with a different number. 
+ auto result_two_args = + client_->Execute(computation, {param0_data.get(), param1_data.get()}); + ASSERT_IS_OK(result_two_args.status()); + + auto result_one_arg = client_->Execute(computation, {param0_data.get()}); + ASSERT_FALSE(result_one_arg.ok()); + ASSERT_EQ(result_one_arg.status().code(), + tensorflow::error::INVALID_ARGUMENT); + ASSERT_MATCH(result_one_arg.status().error_message(), + testing::ContainsRegex("takes 2")); + + auto result_zero_args = client_->Execute(computation, {}); + ASSERT_FALSE(result_zero_args.ok()); + ASSERT_EQ(result_zero_args.status().code(), + tensorflow::error::INVALID_ARGUMENT); + ASSERT_MATCH(result_zero_args.status().error_message(), + testing::ContainsRegex("takes 2")); +} + +XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) { + ComputationBuilder builder(client_, "add_two_params"); + + auto p0 = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param0"); + auto p1 = builder.Parameter(1, ShapeUtil::MakeShape(F32, {4}), "param1"); + auto add = builder.Mul(p0, p1); + + auto computation_status = builder.Build(); + ASSERT_IS_OK(computation_status.status()); + auto computation = computation_status.ConsumeValueOrDie(); + + auto f32_literal = LiteralUtil::CreateR0(1.1f); + auto f32_data = client_->TransferToServer(*f32_literal).ConsumeValueOrDie(); + auto f32_4_literal = LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f, 4.0f}); + auto f32_4_data = + client_->TransferToServer(*f32_4_literal).ConsumeValueOrDie(); + auto u8_4_literal = LiteralUtil::CreateR1U8("hola"); + auto u8_4_data = client_->TransferToServer(*u8_4_literal).ConsumeValueOrDie(); + + // Match + auto status = + client_->Execute(computation, {f32_data.get(), f32_4_data.get()}); + ASSERT_IS_OK(status.status()); + + // Shape mismatch in parameter 0 + status = client_->Execute(computation, {f32_4_data.get(), f32_4_data.get()}); + ASSERT_FALSE(status.ok()); + ASSERT_EQ(status.status().code(), tensorflow::error::INVALID_ARGUMENT); + ASSERT_MATCH(status.status().error_message(), + testing::ContainsRegex("expects parameter 0")); + + // Shape mismatch in parameter 1 (rank) + status = client_->Execute(computation, {f32_data.get(), f32_data.get()}); + ASSERT_FALSE(status.ok()); + ASSERT_EQ(status.status().code(), tensorflow::error::INVALID_ARGUMENT); + ASSERT_MATCH(status.status().error_message(), + testing::ContainsRegex("expects parameter 1")); + + // Shape mismatch in parameter 1 (element type) + status = client_->Execute(computation, {f32_data.get(), u8_4_data.get()}); + ASSERT_FALSE(status.ok()); + ASSERT_EQ(status.status().code(), tensorflow::error::INVALID_ARGUMENT); + ASSERT_MATCH(status.status().error_message(), + testing::ContainsRegex("expects parameter 1")); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc new file mode 100644 index 0000000000..d2a7def5d0 --- /dev/null +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -0,0 +1,263 @@ +/* Copyright 2017 The 
TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" + +#include + +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace se = ::perftools::gputools; + +namespace xla { +namespace { +// Wrapper function that creates a nicer error message (than a bare +// ValueOrDie()) if the platform we intend to test is not available. +Client* GetOrCreateLocalClientOrDie(se::Platform* platform) { + StatusOr result = ClientLibrary::GetOrCreateLocalClient(platform); + TF_CHECK_OK(result.status()) << "could not create local client for testing"; + return result.ValueOrDie(); +} +} // namespace + +ClientLibraryTestBase::ClientLibraryTestBase( + se::Platform* platform, + tensorflow::gtl::ArraySlice disabled_pass_names) + : client_(GetOrCreateLocalClientOrDie(platform)) { + legacy_flags::HloPassPipelineFlags* flags = + legacy_flags::GetHloPassPipelineFlags(); + flags->xla_disable_hlo_passes = + tensorflow::str_util::Join(disabled_pass_names, ","); +} + +string ClientLibraryTestBase::TestName() const { + return ::testing::UnitTest::GetInstance()->current_test_info()->name(); +} + +StatusOr> ClientLibraryTestBase::Execute( + ComputationBuilder* builder, + tensorflow::gtl::ArraySlice arguments) { + // Build the computation, as a convenience. + TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); + return client_->Execute(computation, arguments); +} + +StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( + ComputationBuilder* builder, + tensorflow::gtl::ArraySlice arguments, + const Shape* shape_with_output_layout) { + // Build the computation, as a convenience. 
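+  // (Added note: when shape_with_output_layout is non-null it is forwarded to
+  // the client so the result literal comes back in that layout; layout-aware
+  // tests such as ClientTest.ExecuteWithLayout depend on this.)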
+ TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); + return client_->ExecuteAndTransfer(computation, arguments, + shape_with_output_layout); +} + +std::unique_ptr ClientLibraryTestBase::ExecuteOrDie( + ComputationBuilder* builder, + tensorflow::gtl::ArraySlice arguments) { + return Execute(builder, arguments).ConsumeValueOrDie(); +} + +std::unique_ptr ClientLibraryTestBase::ExecuteAndTransferOrDie( + ComputationBuilder* builder, + tensorflow::gtl::ArraySlice arguments) { + return ExecuteAndTransfer(builder, arguments).ConsumeValueOrDie(); +} + +string ClientLibraryTestBase::ExecuteToString( + ComputationBuilder* builder, + tensorflow::gtl::ArraySlice arguments) { + StatusOr computation_status = builder->Build(); + if (!computation_status.ok()) { + return computation_status.status().ToString(); + } + Computation computation = computation_status.ConsumeValueOrDie(); + + auto result = client_->ExecuteAndTransfer(computation, arguments); + if (!result.ok()) { + return result.status().ToString(); + } else { + return LiteralUtil::ToString(*result.ValueOrDie()); + } +} + +void ClientLibraryTestBase::ComputeAndCompareR1( + ComputationBuilder* builder, const tensorflow::core::Bitmap& expected, + tensorflow::gtl::ArraySlice arguments) { + std::unique_ptr expected_literal = LiteralUtil::CreateR1(expected); + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + arguments); +} + +void ClientLibraryTestBase::ComputeAndCompareLiteral( + ComputationBuilder* builder, const Literal& expected, + tensorflow::gtl::ArraySlice arguments, + const Shape* shape_with_layout) { + EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments, + shape_with_layout)); +} + +void ClientLibraryTestBase::ComputeAndCompareLiteral( + ComputationBuilder* builder, const Literal& expected, + tensorflow::gtl::ArraySlice arguments, ErrorSpec error, + const Shape* shape_with_layout) { + EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments, + error, shape_with_layout)); +} + +tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( + ComputationBuilder* builder, const Literal& expected, + tensorflow::gtl::ArraySlice arguments, + const Shape* shape_with_layout) { + TF_ASSIGN_OR_RETURN( + auto actual, ExecuteAndTransfer(builder, arguments, shape_with_layout)); + if (ShapeUtil::ElementIsFloating(expected.shape())) { + LOG(WARNING) << "performing exact comparison of floating point numbers"; + } else { + TF_RET_CHECK(ShapeUtil::ElementIsIntegral(expected.shape()) || + expected.shape().element_type() == PRED); + } + LiteralTestUtil::ExpectEqual(expected, *actual); + return tensorflow::Status::OK(); +} + +tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( + ComputationBuilder* builder, const Literal& expected, + tensorflow::gtl::ArraySlice arguments, ErrorSpec error, + const Shape* shape_with_layout) { + TF_ASSIGN_OR_RETURN( + auto actual, ExecuteAndTransfer(builder, arguments, shape_with_layout)); + TF_RET_CHECK(ShapeUtil::ElementIsFloating(expected.shape())); + LiteralTestUtil::ExpectNear(expected, *actual, error); + return tensorflow::Status::OK(); +} + +void ClientLibraryTestBase::ComputeAndCompareR1U8( + ComputationBuilder* builder, tensorflow::StringPiece expected, + tensorflow::gtl::ArraySlice arguments) { + auto actual_status = ExecuteAndTransfer(builder, arguments); + EXPECT_IS_OK(actual_status.status()); + if (!actual_status.ok()) { + return; + } + auto actual = actual_status.ConsumeValueOrDie(); + + // Turn the 
expected value into a literal. + std::unique_ptr expected_literal = LiteralUtil::CreateR1U8(expected); + + VLOG(1) << "expected: " << LiteralUtil::ToString(*expected_literal); + VLOG(1) << "actual: " << LiteralUtil::ToString(*actual); + + EXPECT_EQ(expected, actual->u8s()); +} + +void ClientLibraryTestBase::ComputeAndCompareTuple( + ComputationBuilder* builder, const Literal& expected, + tensorflow::gtl::ArraySlice arguments) { + auto actual_status = ExecuteAndTransfer(builder, arguments); + EXPECT_IS_OK(actual_status.status()); + if (!actual_status.ok()) { + return; + } + auto actual = actual_status.ConsumeValueOrDie(); + LiteralTestUtil::ExpectEqualTuple(expected, *actual); +} + +void ClientLibraryTestBase::ComputeAndCompareTuple( + ComputationBuilder* builder, const Literal& expected, + tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { + auto actual_status = ExecuteAndTransfer(builder, arguments); + EXPECT_IS_OK(actual_status.status()); + if (!actual_status.ok()) { + return; + } + auto actual = actual_status.ConsumeValueOrDie(); + LiteralTestUtil::ExpectNearTuple(expected, *actual, error); +} + +Computation ClientLibraryTestBase::CreateScalarRelu() { + ComputationBuilder builder(client_, "relu"); + auto z_value = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "z_value"); + auto zero = builder.ConstantR0(0.0); + builder.Max(z_value, zero); + auto computation_status = builder.Build(); + TF_CHECK_OK(computation_status.status()); + return computation_status.ConsumeValueOrDie(); +} + +Computation ClientLibraryTestBase::CreateScalarMax() { + ComputationBuilder builder(client_, "max"); + auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); + builder.Max(x, y); + auto computation_status = builder.Build(); + TF_CHECK_OK(computation_status.status()); + return computation_status.ConsumeValueOrDie(); +} + +Computation ClientLibraryTestBase::CreateScalarReluSensitivity() { + ComputationBuilder builder(client_, "relu_sensitivity"); + auto activation = + builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "activation"); + auto backprop = + builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "backprop"); + auto zero = builder.ConstantR0(0.0); + auto activation_gtz = builder.Gt(activation, zero); + builder.Select(activation_gtz, /*on_true=*/backprop, /*on_false=*/zero); + + auto computation_status = builder.Build(); + TF_CHECK_OK(computation_status.status()); + return computation_status.ConsumeValueOrDie(); +} + +std::unique_ptr> ClientLibraryTestBase::CreatePatternedMatrix( + int rows, int cols, float offset) { + auto array = MakeUnique>(rows, cols); + for (int64 row = 0; row < rows; ++row) { + for (int64 col = 0; col < cols; ++col) { + (*array)(row, col) = col + (row * 1000.0f) + offset; + } + } + return array; +} + +std::unique_ptr> +ClientLibraryTestBase::CreatePatternedMatrixWithZeroPadding(int rows, int cols, + int rows_padded, + int cols_padded) { + CHECK_GE(rows_padded, rows); + CHECK_GE(cols_padded, cols); + auto array = MakeUnique>(rows_padded, cols_padded, 0.0); + for (int64 row = 0; row < rows; ++row) { + for (int64 col = 0; col < cols; ++col) { + (*array)(row, col) = col + (row * 1000.0f); + } + } + return array; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h new file mode 100644 index 0000000000..690fda3ffa --- /dev/null +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -0,0 
+1,409 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_ +#define TENSORFLOW_COMPILER_XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_ + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/array3d.h" +#include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/bitmap.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// A client library test establishes an in-process XLA client connection. +class ClientLibraryTestBase : public ::testing::Test { + protected: + explicit ClientLibraryTestBase( + perftools::gputools::Platform* platform = nullptr, + tensorflow::gtl::ArraySlice disabled_pass_names = {}); + + // Returns the name of the test currently being run. + string TestName() const; + + // TODO(b/25566808): Add helper that populates a literal from a testdata file. + + // Convenience methods for building and running a computation from a builder. + StatusOr> Execute( + ComputationBuilder* builder, + tensorflow::gtl::ArraySlice arguments); + StatusOr> ExecuteAndTransfer( + ComputationBuilder* builder, + tensorflow::gtl::ArraySlice arguments, + const Shape* shape_with_output_layout = nullptr); + + // Convenience OrDie variants of above methods. + std::unique_ptr ExecuteOrDie( + ComputationBuilder* builder, + tensorflow::gtl::ArraySlice arguments); + std::unique_ptr ExecuteAndTransferOrDie( + ComputationBuilder* builder, + tensorflow::gtl::ArraySlice arguments); + + // Run a computation and return its value as a string. If an error + // occurs, then instead return the error as a string. + string ExecuteToString(ComputationBuilder* builder, + tensorflow::gtl::ArraySlice arguments); + + // Convenience methods for building and running a computation, transferring + // the result, and comparing it to the expected value(s). Methods are + // templated on the native host type which maps to specific XLA types (See + // ComputationBuilder for details). For each rank, two forms are provided: one + // for floating point types with an ErrorSpec parameter, and one for integral + // types without the ErrorSpec parameter. 
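+  //
+  // Illustrative usage (hypothetical test body, added for clarity; not part
+  // of the original header):
+  //
+  //   ComputationBuilder b(client_, TestName());
+  //   b.Add(b.ConstantR0<float>(1.0f), b.ConstantR0<float>(2.0f));
+  //   ComputeAndCompareR0<float>(&b, 3.0f, /*arguments=*/{}, ErrorSpec(0.0001));
+  //
+  // The integral overloads are identical except that no ErrorSpec is passed.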
+ template + void ComputeAndCompareR0(ComputationBuilder* builder, NativeT expected, + tensorflow::gtl::ArraySlice arguments); + template + void ComputeAndCompareR0(ComputationBuilder* builder, NativeT expected, + tensorflow::gtl::ArraySlice arguments, + ErrorSpec error); + + template + void ComputeAndCompareR1(ComputationBuilder* builder, + tensorflow::gtl::ArraySlice expected, + tensorflow::gtl::ArraySlice arguments); + template + void ComputeAndCompareR1(ComputationBuilder* builder, + tensorflow::gtl::ArraySlice expected, + tensorflow::gtl::ArraySlice arguments, + ErrorSpec error); + + // As above, but uses a bitmap to hold the predicate vector to avoid + // deficiencies of vector. + void ComputeAndCompareR1(ComputationBuilder* builder, + const tensorflow::core::Bitmap& expected, + tensorflow::gtl::ArraySlice arguments); + + template + void ComputeAndCompareR2(ComputationBuilder* builder, + const Array2D& expected, + tensorflow::gtl::ArraySlice arguments); + template + void ComputeAndCompareR2(ComputationBuilder* builder, + const Array2D& expected, + tensorflow::gtl::ArraySlice arguments, + ErrorSpec error); + + template + void ComputeAndCompareR3(ComputationBuilder* builder, + const Array3D& expected, + tensorflow::gtl::ArraySlice arguments); + template + void ComputeAndCompareR3(ComputationBuilder* builder, + const Array3D& expected, + tensorflow::gtl::ArraySlice arguments, + ErrorSpec error); + + template + void ComputeAndCompareR4(ComputationBuilder* builder, + const Array4D& expected, + tensorflow::gtl::ArraySlice arguments); + template + void ComputeAndCompareR4(ComputationBuilder* builder, + const Array4D& expected, + tensorflow::gtl::ArraySlice arguments, + ErrorSpec error); + + // Build and run the computation and compare the result with the given + // literal. shape_with_layout indicates the result layout to request when + // calling Execute. + void ComputeAndCompareLiteral( + ComputationBuilder* builder, const Literal& expected, + tensorflow::gtl::ArraySlice arguments, + const Shape* shape_with_layout = nullptr); + void ComputeAndCompareLiteral( + ComputationBuilder* builder, const Literal& expected, + tensorflow::gtl::ArraySlice arguments, ErrorSpec error, + const Shape* shape_with_layout = nullptr); + + // ComputeAndCompare variant which returns an error status. + tensorflow::Status ComputeAndCompareLiteralWithStatus( + ComputationBuilder* builder, const Literal& expected, + tensorflow::gtl::ArraySlice arguments, + const Shape* shape_with_layout = nullptr); + tensorflow::Status ComputeAndCompareLiteralWithStatus( + ComputationBuilder* builder, const Literal& expected, + tensorflow::gtl::ArraySlice arguments, ErrorSpec error, + const Shape* shape_with_layout = nullptr); + + // Compare the result of the computation to a strings. In XLA strings are + // represented using rank-1 U8 shapes. + void ComputeAndCompareR1U8( + ComputationBuilder* builder, tensorflow::StringPiece expected, + tensorflow::gtl::ArraySlice arguments); + + // Convenience method for running a built computation, transferring the + // result, and comparing it to the expected tuple literal. + void ComputeAndCompareTuple( + ComputationBuilder* builder, const Literal& expected, + tensorflow::gtl::ArraySlice arguments); + void ComputeAndCompareTuple( + ComputationBuilder* builder, const Literal& expected, + tensorflow::gtl::ArraySlice arguments, ErrorSpec abs_error); + + // Create scalar operations for use in reductions. 
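+  // (Added note: these return Computations intended to be passed as the
+  // to-apply computation of reduction-style builder ops; for example,
+  // CreateScalarMax() builds max(x, y) over two F32 scalars.)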
+ Computation CreateScalarRelu(); + Computation CreateScalarMax(); + Computation CreateScalarReluSensitivity(); + + // Special case convenience functions for creating filled arrays. + + // Creates an array of pseudorandom values lying between the given minimum and + // maximum values. + template + std::vector CreatePseudorandomR1(const int width, NativeT min_value, + NativeT max_value, uint32 seed); + template + std::unique_ptr> CreatePseudorandomR2(const int rows, + const int cols, + NativeT min_value, + NativeT max_value, + uint32 seed); + + // Creates a (rows x cols) array filled in the following form: + // + // [ 0 1 ... cols-1] + // [ 1,000 1,001 ... 1000.0 + cols-1] + // [ ... ... ... ...] + // [(rows-1)*1000.0 ... ... (rows-1)*1000.0 + cols-1] + // + // If provided, offset is added uniformly to every element (e.g. an offset of + // 64 would cause 0 in the above to be 64, 1 to be 65, 1000 to be 1064, etc.) + std::unique_ptr> CreatePatternedMatrix(const int rows, + const int cols, + float offset = 0.0); + + // Creates a (rows x cols) array as above, padded out to + // (rows_padded x cols_padded) with zeroes. Requires rows_padded >= rows + // and cols_padded > cols. + std::unique_ptr> CreatePatternedMatrixWithZeroPadding( + const int rows, const int cols, const int rows_padded, + const int cols_padded); + + // Create a parameter instruction that wraps the given values and then stores + // into "data_handle" the global handle for that parameter. + // + // "parameter_number" is the parameter number. + // "name" is the name of the parameter instruction. + template + std::unique_ptr CreateR1Parameter( + tensorflow::gtl::ArraySlice values, int64 parameter_number, + const string& name, ComputationBuilder* builder, + ComputationDataHandle* data_handle); + + // Create a parameter instruction that wraps the given constant array + // "array_2d" and then stores to "data_handle" the global handle for that + // parameter. + // + // "parameter_number" is the parameter number. + // "name" is the name of the parameter instruction. 
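+  //
+  // Illustrative usage (hypothetical, added for clarity):
+  //
+  //   ComputationDataHandle param;
+  //   auto data = CreateR2Parameter<float>(array_2d, /*parameter_number=*/0,
+  //                                        "input", &builder, &param);
+  //   // ... build ops on `param`, then pass {data.get()} as the arguments to
+  //   // one of the ComputeAndCompare* helpers.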
+ template + std::unique_ptr CreateR2Parameter( + const Array2D& array_2d, int64 parameter_number, + const string& name, ComputationBuilder* builder, + ComputationDataHandle* data_handle); + + Client* client_; +}; + +template +void ClientLibraryTestBase::ComputeAndCompareR0( + ComputationBuilder* builder, NativeT expected, + tensorflow::gtl::ArraySlice arguments) { + std::unique_ptr expected_literal = + LiteralUtil::CreateR0(expected); + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + arguments); +} + +template +void ClientLibraryTestBase::ComputeAndCompareR0( + ComputationBuilder* builder, NativeT expected, + tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { + static_assert(std::is_same::value || + std::is_same::value, + "Floating point type required when specifying an ErrorSpec"); + std::unique_ptr expected_literal = + LiteralUtil::CreateR0(expected); + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + arguments, error); +} + +template +void ClientLibraryTestBase::ComputeAndCompareR1( + ComputationBuilder* builder, tensorflow::gtl::ArraySlice expected, + tensorflow::gtl::ArraySlice arguments) { + std::unique_ptr expected_literal = + LiteralUtil::CreateR1(expected); + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + arguments); +} + +template +void ClientLibraryTestBase::ComputeAndCompareR1( + ComputationBuilder* builder, tensorflow::gtl::ArraySlice expected, + tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { + static_assert(std::is_same::value || + std::is_same::value, + "Floating point type required when specifying an ErrorSpec"); + std::unique_ptr expected_literal = + LiteralUtil::CreateR1(expected); + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + arguments, error); +} + +template +void ClientLibraryTestBase::ComputeAndCompareR2( + ComputationBuilder* builder, const Array2D& expected, + tensorflow::gtl::ArraySlice arguments) { + std::unique_ptr expected_literal = + LiteralUtil::CreateR2FromArray2D(expected); + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + arguments); +} + +template +void ClientLibraryTestBase::ComputeAndCompareR2( + ComputationBuilder* builder, const Array2D& expected, + tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { + static_assert(std::is_same::value || + std::is_same::value, + "Floating point type required when specifying an ErrorSpec"); + std::unique_ptr expected_literal = + LiteralUtil::CreateR2FromArray2D(expected); + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + arguments, error); +} + +template +void ClientLibraryTestBase::ComputeAndCompareR3( + ComputationBuilder* builder, const Array3D& expected, + tensorflow::gtl::ArraySlice arguments) { + std::unique_ptr expected_literal = + LiteralUtil::CreateR3FromArray3D(expected); + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + arguments); +} + +template +void ClientLibraryTestBase::ComputeAndCompareR3( + ComputationBuilder* builder, const Array3D& expected, + tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { + static_assert(std::is_same::value || + std::is_same::value, + "Floating point type required when specifying an ErrorSpec"); + std::unique_ptr expected_literal = + LiteralUtil::CreateR3FromArray3D(expected); + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + arguments, error); +} + +template +void 
ClientLibraryTestBase::ComputeAndCompareR4( + ComputationBuilder* builder, const Array4D& expected, + tensorflow::gtl::ArraySlice arguments) { + std::unique_ptr expected_literal = + LiteralUtil::CreateR4FromArray4D(expected); + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + arguments); +} + +template +void ClientLibraryTestBase::ComputeAndCompareR4( + ComputationBuilder* builder, const Array4D& expected, + tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { + static_assert(std::is_same::value || + std::is_same::value, + "Floating point type required when specifying an ErrorSpec"); + std::unique_ptr expected_literal = + LiteralUtil::CreateR4FromArray4D(expected); + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + arguments, error); +} + +template +std::unique_ptr ClientLibraryTestBase::CreateR1Parameter( + tensorflow::gtl::ArraySlice values, int64 parameter_number, + const string& name, ComputationBuilder* builder, + ComputationDataHandle* data_handle) { + std::unique_ptr literal = LiteralUtil::CreateR1(values); + std::unique_ptr data = + client_->TransferToServer(*literal).ConsumeValueOrDie(); + *data_handle = builder->Parameter(parameter_number, literal->shape(), name); + return data; +} + +template +std::unique_ptr ClientLibraryTestBase::CreateR2Parameter( + const Array2D& array_2d, int64 parameter_number, + const string& name, ComputationBuilder* builder, + ComputationDataHandle* data_handle) { + std::unique_ptr literal = LiteralUtil::CreateR2FromArray2D(array_2d); + std::unique_ptr data = + client_->TransferToServer(*literal).ConsumeValueOrDie(); + *data_handle = builder->Parameter(parameter_number, literal->shape(), name); + return data; +} + +template +std::vector ClientLibraryTestBase::CreatePseudorandomR1( + const int width, NativeT min_value, NativeT max_value, uint32 seed) { + std::vector result(width); + test_utils::PseudorandomGenerator generator(min_value, max_value, + seed); + for (int i = 0; i < width; ++i) { + result[i] = generator.get(); + } + return result; +} + +template +std::unique_ptr> ClientLibraryTestBase::CreatePseudorandomR2( + const int rows, const int cols, NativeT min_value, NativeT max_value, + uint32 seed) { + auto result = MakeUnique>(rows, cols); + test_utils::PseudorandomGenerator generator(min_value, max_value, + seed); + for (int y = 0; y < rows; ++y) { + for (int x = 0; x < cols; ++x) { + (*result)(y, x) = generator.get(); + } + } + return result; +} + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_ diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc new file mode 100644 index 0000000000..77b85af83c --- /dev/null +++ b/tensorflow/compiler/xla/tests/client_test.cc @@ -0,0 +1,127 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class ClientTest : public ClientLibraryTestBase {}; + +TEST_F(ClientTest, ExecuteWithLayout) { + ComputationBuilder b(client_, TestName()); + + std::vector> layouts = {{0, 1}, {1, 0}}; + for (const std::vector& execute_layout : layouts) { + for (const std::vector& transfer_layout : layouts) { + b.Add(b.ConstantR2({{1, 2}, {3, 4}}), + b.ConstantR2({{10, 20}, {30, 40}})); + auto computation = b.Build(); + ASSERT_TRUE(computation.ok()) << computation.status(); + + const Shape execute_shape_with_layout = ShapeUtil::MakeShapeWithLayout( + S32, /*dimensions=*/{2, 2}, execute_layout); + std::unique_ptr data = + client_ + ->Execute(computation.ValueOrDie(), {}, + &execute_shape_with_layout) + .ConsumeValueOrDie(); + + std::unique_ptr expected_literal = + test_utils::CreateR2LiteralWithLayout({{11, 22}, {33, 44}}, + transfer_layout); + + auto computed = client_->Transfer(*data, &expected_literal->shape()); + + LiteralTestUtil::AssertEqualShapesAndLayouts( + expected_literal->shape(), computed.ValueOrDie()->shape()); + LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie()); + } + } +} + +TEST_F(ClientTest, ExecuteWithTupleLayout) { + ComputationBuilder b(client_, TestName()); + + b.Tuple({b.ConstantR2({{1, 2}, {3, 4}}), + b.ConstantR2({{10, 20}, {30, 40}})}); + + auto computation = b.Build(); + ASSERT_TRUE(computation.ok()) << computation.status(); + + // Create a result shape with one element column major and the other row + // major. 
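+  // (Added note: a minor_to_major list orders dimensions from fastest to
+  // slowest varying, so {0, 1} below is column major and {1, 0} is row major
+  // for these 2x2 shapes.)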
+ Shape tuple_shape_with_layout = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2}, + /*minor_to_major=*/{0, 1}), + ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2}, + /*minor_to_major=*/{1, 0})}); + + auto result = client_ + ->ExecuteAndTransfer(computation.ValueOrDie(), {}, + &tuple_shape_with_layout) + .ConsumeValueOrDie(); + LiteralTestUtil::ExpectR2Equal({{1, 2}, {3, 4}}, + result->tuple_literals(0)); + LiteralTestUtil::ExpectR2Equal({{10, 20}, {30, 40}}, + result->tuple_literals(1)); + + EXPECT_TRUE(ShapeUtil::IsTuple(result->shape())); + EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->shape())); + + EXPECT_TRUE(ShapeUtil::Equal( + ShapeUtil::GetTupleElementShape(result->shape(), 0), + ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2}, + /*minor_to_major=*/{0, 1}))); + EXPECT_TRUE(ShapeUtil::Equal( + ShapeUtil::GetTupleElementShape(result->shape(), 1), + ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2}, + /*minor_to_major=*/{1, 0}))); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/codegen_test_base.cc b/tensorflow/compiler/xla/tests/codegen_test_base.cc new file mode 100644 index 0000000000..fe4dff2109 --- /dev/null +++ b/tensorflow/compiler/xla/tests/codegen_test_base.cc @@ -0,0 +1,90 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/tests/codegen_test_base.h" + +#include +#include + +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/backend.h" +#include "tensorflow/compiler/xla/service/compiler.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/subprocess.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { + +void CodegenTestBase::CompileAndVerifyIr(std::unique_ptr hlo_module, + const string& pattern) { + std::unique_ptr executable = + CompileToExecutable(std::move(hlo_module)); + string ir_module_string = GetIrFromExecutable(*executable); + RunFileCheck(ir_module_string, pattern); +} + +std::unique_ptr CodegenTestBase::CompileToExecutable( + std::unique_ptr hlo_module) { + auto module_config = MakeUnique( + MakeProgramShape(hlo_module->entry_computation())); + return backend_->compiler() + ->Compile(std::move(hlo_module), std::move(module_config), + test_hlo_dumper_, backend_->default_stream_executor()) + .ConsumeValueOrDie(); +} + +void CodegenTestBase::RunFileCheck(const string& input, const string& pattern) { + // Write input to a temporary file. + char tempdir_template[] = "/tmp/ir_testXXXXXX"; + char* tempdir_name = mkdtemp(tempdir_template); + CHECK_NOTNULL(tempdir_name); + string pattern_path = + tensorflow::io::JoinPath(tempdir_name, "xla_hlo_test_ir_pattern"); + TF_CHECK_OK(tensorflow::WriteStringToFile(tensorflow::Env::Default(), + pattern_path, pattern)); + + // Invoke FileCheck to check whether input matches `pattern`. + tensorflow::SubProcess file_check_process; + const char* test_srcdir = getenv("TEST_SRCDIR"); + if (test_srcdir == nullptr) { + test_srcdir = "."; + } + string file_check_path = tensorflow::io::JoinPath( + test_srcdir, "external/llvm/FileCheck"); + file_check_process.SetProgram(file_check_path, + {file_check_path, pattern_path}); + file_check_process.SetChannelAction(tensorflow::CHAN_STDIN, + tensorflow::ACTION_PIPE); + file_check_process.SetChannelAction(tensorflow::CHAN_STDERR, + tensorflow::ACTION_PIPE); + CHECK(file_check_process.Start()); + string standard_error; + int exit_status = file_check_process.Communicate( + /*stdin_input=*/&input, /*stdout_output=*/nullptr, + /*stderr_output=*/&standard_error); + + // FileCheck returns 0 when the inputs match. If matching failed, we output + // the error message generated by FileCheck. + SCOPED_TRACE(tensorflow::strings::StrCat("Input to FileCheck:\n", input)); + EXPECT_EQ(0, exit_status) << standard_error; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/codegen_test_base.h b/tensorflow/compiler/xla/tests/codegen_test_base.h new file mode 100644 index 0000000000..50c0453107 --- /dev/null +++ b/tensorflow/compiler/xla/tests/codegen_test_base.h @@ -0,0 +1,56 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_TESTS_CODEGEN_TEST_BASE_H_ +#define TENSORFLOW_COMPILER_XLA_TESTS_CODEGEN_TEST_BASE_H_ + +#include +#include + +#include "tensorflow/compiler/xla/service/executable.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" + +namespace xla { + +// Tests that verify IR emitted by the CPU/GPU backend is as expected. +class CodegenTestBase : public HloTestBase { + protected: + CodegenTestBase() {} + + // Returns the embedded LLVM IR from the given executable. Codegen tests must + // override this method, but execution tests do not have to because they do + // not examine the embedded IR. + virtual string GetIrFromExecutable(const Executable& executable) = 0; + + // Compiles the given HLO module to LLVM IR and verifies the IR matches the + // given pattern. `pattern` is in the FileCheck pattern matching syntax + // (http://llvm.org/docs/CommandGuide/FileCheck.html). + void CompileAndVerifyIr(std::unique_ptr hlo_module, + const string& pattern); + + protected: + // Compiles hlo_module to an executable, CHECK-failing if this fails. + std::unique_ptr CompileToExecutable( + std::unique_ptr hlo_module); + + // Runs FileCheck with the given pattern over the given string and EXPECTs + // that FileCheck succeeded in matching the input. + void RunFileCheck(const string& input, const string& pattern); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_TESTS_CODEGEN_TEST_BASE_H_ diff --git a/tensorflow/compiler/xla/tests/compilation_cache_test.cc b/tensorflow/compiler/xla/tests/compilation_cache_test.cc new file mode 100644 index 0000000000..38ce007cb0 --- /dev/null +++ b/tensorflow/compiler/xla/tests/compilation_cache_test.cc @@ -0,0 +1,218 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include <initializer_list>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
+#include "tensorflow/compiler/xla/xla.pb.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+class CompilationCacheTest : public ClientLibraryTestBase {
+ public:
+  void ExecuteComputationR0F32(
+      const Computation& computation,
+      tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+      float expected_result, bool expect_cache_hit) {
+    ExecutionProfile execution_profile;
+    std::unique_ptr<Literal> result =
+        client_
+            ->ExecuteAndTransfer(computation, arguments,
+                                 /*output_layout=*/nullptr, &execution_profile)
+            .ConsumeValueOrDie();
+    LiteralTestUtil::ExpectNear(*LiteralUtil::CreateR0<float>(expected_result),
+                                *result, error_spec_);
+    EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit());
+  }
+
+  void ExecuteComputationR2F32(
+      const Computation& computation,
+      tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+      std::initializer_list<std::initializer_list<float>> expected_result,
+      bool expect_cache_hit) {
+    ExecutionProfile execution_profile;
+    auto data_handle =
+        client_
+            ->Execute(computation, arguments, /*output_layout=*/nullptr,
+                      &execution_profile)
+            .ConsumeValueOrDie();
+    std::unique_ptr<Literal> result =
+        client_->Transfer(*data_handle).ConsumeValueOrDie();
+    LiteralTestUtil::ExpectNear(*LiteralUtil::CreateR2<float>(expected_result),
+                                *result, error_spec_);
+    EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit());
+  }
+
+  ErrorSpec error_spec_{0.0001};
+};
+
+XLA_TEST_F(CompilationCacheTest, ComputationCalledMultipleTimes) {
+  ComputationBuilder builder(client_, TestName());
+  builder.Neg(builder.ConstantR0<float>(42.0));
+  Computation computation = builder.Build().ConsumeValueOrDie();
+
+  ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/false);
+  ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/true);
+  ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/true);
+}
+
+XLA_TEST_F(CompilationCacheTest, ComputationCalledWithDifferentParameters) {
+  std::unique_ptr<GlobalData> data_42 =
+      client_->TransferToServer(*LiteralUtil::CreateR0<float>(42.0f))
+          .ConsumeValueOrDie();
+  std::unique_ptr<GlobalData> data_123 =
+      client_->TransferToServer(*LiteralUtil::CreateR0<float>(123.0f))
+          .ConsumeValueOrDie();
+  std::unique_ptr<GlobalData> data_456 =
+      client_->TransferToServer(*LiteralUtil::CreateR0<float>(456.0f))
+          .ConsumeValueOrDie();
+
+  ComputationBuilder builder(client_, TestName());
+  builder.Neg(builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param"));
+  Computation computation = builder.Build().ConsumeValueOrDie();
+
+  ExecuteComputationR0F32(computation, {data_42.get()}, -42.0,
+                          /*expect_cache_hit=*/false);
+  ExecuteComputationR0F32(computation, {data_123.get()}, -123.0,
+                          /*expect_cache_hit=*/true);
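+  // The executable is cached per computation and per argument shape/layout,
+  // not per argument value, so re-running with new values of the same F32
+  // scalar parameter is expected to hit the cache.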
ExecuteComputationR0F32(computation, {data_456.get()}, -456.0, + /*expect_cache_hit=*/true); + ExecuteComputationR0F32(computation, {data_42.get()}, -42.0, + /*expect_cache_hit=*/true); +} + +XLA_TEST_F(CompilationCacheTest, MultipleComputations) { + ComputationBuilder builder_neg(client_, TestName() + "_neg"); + builder_neg.Neg(builder_neg.ConstantR0(42.0)); + Computation computation_neg = builder_neg.Build().ConsumeValueOrDie(); + + ComputationBuilder builder_exp(client_, TestName() + "_exp"); + builder_exp.Exp(builder_exp.ConstantR0(1.0)); + Computation computation_exp = builder_exp.Build().ConsumeValueOrDie(); + + ComputationBuilder builder_add(client_, TestName() + "_add"); + builder_add.Add(builder_add.ConstantR0(2.0), + builder_add.ConstantR0(3.0)); + Computation computation_add = builder_add.Build().ConsumeValueOrDie(); + + ExecuteComputationR0F32(computation_neg, {}, -42.0, + /*expect_cache_hit=*/false); + ExecuteComputationR0F32(computation_exp, {}, 2.7182817, + /*expect_cache_hit=*/false); + ExecuteComputationR0F32(computation_add, {}, 5.0, + /*expect_cache_hit=*/false); + ExecuteComputationR0F32(computation_neg, {}, -42.0, + /*expect_cache_hit=*/true); +} + +XLA_TEST_F(CompilationCacheTest, DifferentParameterLayouts) { + // Create two GlobalData arrays with the same shape but different + // layouts. Use these arrays as parameters to a simple computation. If the + // layout of the array changes then computation should be recompiled (cache + // miss). + auto rowmaj_array = test_utils::CreateR2LiteralWithLayout( + {{1.0f, 2.0f}, {3.0f, 4.0f}}, /*minor_to_major=*/{1, 0}); + auto rowmaj_handle = + client_->TransferToServer(*rowmaj_array).ConsumeValueOrDie(); + + auto colmaj_array = test_utils::CreateR2LiteralWithLayout( + {{1.0f, 2.0f}, {3.0f, 4.0f}}, /*minor_to_major=*/{0, 1}); + auto colmaj_handle = + client_->TransferToServer(*colmaj_array).ConsumeValueOrDie(); + + ComputationBuilder builder(client_, TestName()); + builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"); + Computation computation = builder.Build().ConsumeValueOrDie(); + + ExecuteComputationR2F32(computation, {colmaj_handle.get()}, + {{1.0f, 2.0f}, {3.0f, 4.0f}}, + /*expect_cache_hit=*/false); + ExecuteComputationR2F32(computation, {colmaj_handle.get()}, + {{1.0f, 2.0f}, {3.0f, 4.0f}}, + /*expect_cache_hit=*/true); + ExecuteComputationR2F32(computation, {rowmaj_handle.get()}, + {{1.0f, 2.0f}, {3.0f, 4.0f}}, + /*expect_cache_hit=*/false); + ExecuteComputationR2F32(computation, {rowmaj_handle.get()}, + {{1.0f, 2.0f}, {3.0f, 4.0f}}, + /*expect_cache_hit=*/true); + ExecuteComputationR2F32(computation, {colmaj_handle.get()}, + {{1.0f, 2.0f}, {3.0f, 4.0f}}, + /*expect_cache_hit=*/true); +} + +XLA_TEST_F(CompilationCacheTest, MutatedComputation) { + // Build a computation, execute it, then mutate it. The mutated computation + // should not be in the cache until it is run once. This must be done through + // the stub interface because Computations built from ComputationBuilder are + // immutable. 
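+  // The raw BinaryOpRequest below appends neg + neg to the already-built
+  // computation, so it now produces -84; the executable cached for the old
+  // version no longer matches, and the next execution must recompile.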
+ ComputationBuilder builder(client_, TestName()); + auto neg = builder.Neg(builder.ConstantR0(42.0)); + Computation computation = builder.Build().ConsumeValueOrDie(); + + ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/false); + ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/true); + + BinaryOpRequest request; + request.set_binop(BINOP_ADD); + *request.mutable_lhs() = neg; + *request.mutable_rhs() = neg; + OpRequest op_request; + *op_request.mutable_computation() = computation.handle(); + *op_request.mutable_binary_op_request() = request; + OpResponse response; + tensorflow::Status s = client_->stub()->Op(&op_request, &response); + ASSERT_TRUE(s.ok()); + + ExecuteComputationR0F32(computation, {}, -84.0, /*expect_cache_hit=*/false); + ExecuteComputationR0F32(computation, {}, -84.0, /*expect_cache_hit=*/true); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc new file mode 100644 index 0000000000..709ce5029c --- /dev/null +++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc @@ -0,0 +1,249 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class ComputeConstantTest : public ClientLibraryTestBase {
+ public:
+  StatusOr<std::unique_ptr<Literal>> ComputeConstantLiteral(
+      ComputationDataHandle operand, ComputationBuilder* builder,
+      Layout* output_layout = nullptr) {
+    TF_ASSIGN_OR_RETURN(auto remote_computed,
+                        builder->ComputeConstant(operand, output_layout));
+    TF_ASSIGN_OR_RETURN(auto computed, client_->Transfer(*remote_computed));
+    return std::move(computed);
+  }
+
+  template <typename Scalar>
+  StatusOr<Scalar> ComputeConstantScalar(ComputationDataHandle operand,
+                                         ComputationBuilder* builder) {
+    TF_ASSIGN_OR_RETURN(auto literal, ComputeConstantLiteral(operand, builder));
+    return LiteralUtil::Get<Scalar>(*literal, {});
+  }
+
+  bool IsConstant(const ComputationDataHandle& operand,
+                  ComputationBuilder* builder) {
+    StatusOr<bool> result = builder->IsConstant(operand);
+    EXPECT_TRUE(result.ok()) << result.status();
+    return result.ok() ?
result.ValueOrDie() : false; + } + + template + void ExpectConstantComputedScalar(ComputationDataHandle operand, + Scalar expected, + ComputationBuilder* builder) { + Scalar computed = ComputeConstantScalar(operand, builder); + ASSERT_TRUE(computed.ok()) << computed.status(); + std::unique_ptr expected_literal = LiteralUtil::CreateR0(expected); + LiteralTestUtil::ExpectEqual(*expected_literal, *computed); + } +}; + +TEST_F(ComputeConstantTest, ScalarInt32Literal) { + ComputationBuilder b(client_, TestName()); + auto computation = b.ConstantR0(42); + EXPECT_TRUE(IsConstant(computation, &b)); + + auto value = ComputeConstantScalar(computation, &b); + ASSERT_TRUE(value.ok()) << value.status(); + EXPECT_EQ(value.ValueOrDie(), 42); +} + +TEST_F(ComputeConstantTest, ScalarFloatAdd) { + ComputationBuilder b(client_, TestName()); + auto computation = + b.Add(b.ConstantR0(42.5f), b.ConstantR0(1.5f)); + EXPECT_TRUE(IsConstant(computation, &b)); + + auto value = ComputeConstantScalar(computation, &b); + ASSERT_TRUE(value.ok()) << value.status(); + EXPECT_EQ(value.ValueOrDie(), 44.0f); +} + +TEST_F(ComputeConstantTest, ScalarRng) { + ComputationBuilder b(client_, TestName()); + auto computation = + b.RngUniform(b.ConstantR0(1.1f), b.ConstantR0(2.1f), + ShapeUtil::MakeShape(F32, {})); + EXPECT_FALSE(IsConstant(computation, &b)); + + auto value = ComputeConstantScalar(computation, &b); + ASSERT_FALSE(value.ok()) + << "computing a RNG value should not be considered a constant"; +} + +TEST_F(ComputeConstantTest, DirectParam) { + ComputationBuilder b(client_, TestName()); + auto computation = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param"); + EXPECT_FALSE(IsConstant(computation, &b)); + + auto value = ComputeConstantScalar(computation, &b); + EXPECT_TRUE(tensorflow::StringPiece(value.status().ToString()) + .contains("depends on parameter")) + << value.status(); +} + +TEST_F(ComputeConstantTest, IndirectParam) { + ComputationBuilder b(client_, TestName()); + auto computation = + b.Add(b.ConstantR0(1.0f), + b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param")); + EXPECT_FALSE(IsConstant(computation, &b)); + + auto value = ComputeConstantScalar(computation, &b); + EXPECT_TRUE(tensorflow::StringPiece(value.status().ToString()) + .contains("depends on parameter")) + << value.status(); +} + +// Test computation of an expression interspersed with param nodes but +// the expression does not depend on the param nodes. 
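+// Below, constant_13 = (2.5 + 1.5) + (2.0 * 4.5) = 4 + 9 = 13; it remains a
+// constant even though the surrounding adds mix in parameter values.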
+TEST_F(ComputeConstantTest, UnrelatedParam) { + ComputationBuilder b(client_, TestName()); + + auto param_a = b.Parameter(10, ShapeUtil::MakeShape(F32, {}), "param0"); + auto constant_4 = b.Add(b.ConstantR0(2.5f), b.ConstantR0(1.5f)); + auto not_constant_a = b.Add(constant_4, param_a); + + auto param_b = b.Parameter(1, ShapeUtil::MakeShape(F32, {}), "param1"); + auto constant_9 = b.Mul(b.ConstantR0(2.0f), b.ConstantR0(4.5f)); + auto not_constant_b = b.Add(param_b, constant_9); + + auto constant_13 = b.Add(constant_4, constant_9); + b.Add(not_constant_b, b.Add(constant_13, not_constant_a)); + + EXPECT_TRUE(IsConstant(constant_13, &b)); + + auto value = ComputeConstantScalar(constant_13, &b); + ASSERT_TRUE(value.ok()) << value.status(); + EXPECT_EQ(value.ValueOrDie(), 13.0f); +} + +TEST_F(ComputeConstantTest, NonScalarAdd) { + ComputationBuilder b(client_, TestName()); + + auto computation = + b.Add(b.ConstantR1({1, 2}), b.ConstantR1({3, 4})); + EXPECT_TRUE(IsConstant(computation, &b)); + + auto computed = ComputeConstantLiteral(computation, &b); + ASSERT_TRUE(computed.ok()) << computed.status(); + std::unique_ptr expected_literal = + LiteralUtil::CreateR1({4, 6}); + LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie()); +} + +TEST_F(ComputeConstantTest, IntegerDivide) { + ComputationBuilder b(client_, TestName()); + auto computation = b.Div(b.ConstantR0(15), b.ConstantR0(3)); + EXPECT_TRUE(IsConstant(computation, &b)); + + auto computed = ComputeConstantLiteral(computation, &b); + ASSERT_TRUE(computed.ok()) << computed.status(); + std::unique_ptr expected_literal = LiteralUtil::CreateR0(5); + LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie()); +} + +XLA_TEST_F(ComputeConstantTest, Layout) { + ComputationBuilder b(client_, TestName()); + + std::vector> layouts = {{0, 1}, {1, 0}}; + for (const std::vector& layout : layouts) { + auto layout_proto = LayoutUtil::MakeLayout(layout); + auto computed = + ComputeConstantLiteral(b.Add(b.ConstantR2({{1, 2}, {3, 4}}), + b.ConstantR2({{10, 20}, {30, 40}})), + &b, &layout_proto); + ASSERT_TRUE(computed.ok()) << computed.status(); + + std::unique_ptr expected_literal = + test_utils::CreateR2LiteralWithLayout({{11, 22}, {33, 44}}, + layout); + LiteralTestUtil::AssertEqualShapesAndLayouts( + expected_literal->shape(), computed.ValueOrDie()->shape()); + LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie()); + } +} + +// This test is permanently disabled on CPU because it requires that the +// backend used for execution is different than the backend used for +// ComputeConstant which is always cpu. +TEST_F(ComputeConstantTest, DISABLED_ON_CPU(ReuseComputedConstant)) { + // Compute a trivial constant, then try to use the value in an Execute + // call. This should fail because the constant resides on the CPU and the + // Execute call is executed on a different backend. + ComputationBuilder constant_b(client_, TestName()); + auto constant = constant_b.ConstantR0(42); + auto handle = constant_b.ComputeConstant(constant).ConsumeValueOrDie(); + auto literal = client_->Transfer(*handle).ConsumeValueOrDie(); + LiteralTestUtil::ExpectR0Equal(42, *literal); + + // Build trivial computation which takes one parameter. + ComputationBuilder b(client_, TestName()); + b.Neg(b.Parameter(0, ShapeUtil::MakeShape(S32, {}), "param0")); + auto computation = b.Build().ConsumeValueOrDie(); + + // Try to use value from ComputeConstant in Execute. 
+ auto execute_status = client_->Execute(computation, {handle.get()}); + EXPECT_FALSE(execute_status.ok()); + EXPECT_MATCH( + execute_status.status().error_message(), + testing::ContainsRegex("argument 0 is on device Host:0 but computation " + "will be executed on device")); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc new file mode 100644 index 0000000000..9a48b19b96 --- /dev/null +++ b/tensorflow/compiler/xla/tests/concat_test.cc @@ -0,0 +1,523 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/array3d.h" +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/reference_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +using ConcatTest = ClientLibraryTestBase; + +// Concatenate expects at least one argument. +XLA_TEST_F(ConcatTest, Concat_Nothing) { + ComputationBuilder builder(client_, TestName()); + auto concatenated = builder.ConcatInDim({}, 0); + StatusOr computation_status = builder.Build(); + ASSERT_FALSE(computation_status.ok()); + EXPECT_MATCH( + computation_status.status().ToString(), + testing::ContainsRegex("Concatenate expects at least one argument")); +} + +// Concatenate with one argument works. +XLA_TEST_F(ConcatTest, Concat_R1_With_Nothing) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({42.0, 64.0}); + auto concatenated = builder.ConcatInDim({a}, 0); + + std::vector expected = {42, 64}; + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); +} + +// Show that we can't concatenate R0 with R0 because we can't name the dimension +// to concatenate on. 
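+// A scalar has rank zero, so there is no dimension index that ConcatInDim
+// could take. One workaround would be, roughly, to reshape each operand to a
+// rank-1 array of length one first, e.g.:
+//   builder.ConcatInDim({builder.Reshape(a, /*dimensions=*/{}, /*new_sizes=*/{1}),
+//                        builder.Reshape(b, /*dimensions=*/{}, /*new_sizes=*/{1})},
+//                       0);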
+XLA_TEST_F(ConcatTest, CannotConcatR0WithR0) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR0(42.0); + auto b = builder.ConstantR0(64.0); + auto concatenated = builder.ConcatInDim({a, b}, 0); + StatusOr computation_status = builder.Build(); + ASSERT_FALSE(computation_status.ok()); + EXPECT_MATCH(computation_status.status().ToString(), + testing::ContainsRegex( + "dimension to concatenate along out of bounds: 0")); +} + +XLA_TEST_F(ConcatTest, Concat_R1_L0_With_R1_L0) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + auto b = builder.ConstantR1({}); + auto concatenated = builder.ConcatInDim({a, b}, 0); + + std::vector expected = {}; + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(ConcatTest, Concat_R1_L0_With_R1_L1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + auto b = builder.ConstantR1({256.0}); + auto concatenated = builder.ConcatInDim({a, b}, 0); + + std::vector expected = {256}; + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(ConcatTest, Concat_R1_L2_With_R1_L0) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({42.0, 64.0}); + auto b = builder.ConstantR1({}); + auto concatenated = builder.ConcatInDim({a, b}, 0); + + std::vector expected = {42, 64}; + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(ConcatTest, Concat_R1_L2_With_R1_L1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({42.0, 64.0}); + auto b = builder.ConstantR1({256.0}); + auto concatenated = builder.ConcatInDim({a, b}, 0); + + std::vector expected = {42, 64, 256}; + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(ConcatTest, Concat_R1_L253_With_R1_L7) { + std::vector lhs(253); + std::vector rhs(7); + std::vector expected(253 + 7); + for (int i = 0; i < 253; ++i) { + expected[i] = lhs[i] = i + 1; + } + for (int i = 0; i < 7; ++i) { + expected[253 + i] = rhs[i] = 253 + i + 1; + } + + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1(lhs); + auto b = builder.ConstantR1(rhs); + auto concatenated = builder.ConcatInDim({a, b}, 0); + + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(ConcatTest, Concat_0x0_With_0x0) { + for (int dim : {0, 1}) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2FromArray2D(Array2D(0, 0)); + auto b = builder.ConstantR2FromArray2D(Array2D(0, 0)); + auto concatenated = builder.ConcatInDim({a, b}, dim); + + ComputeAndCompareR2(&builder, Array2D(0, 0), {}, + ErrorSpec(0.0001)); + } +} + +XLA_TEST_F(ConcatTest, Concat_1x1_With_1x1_InDim0) { + ComputationBuilder builder(client_, TestName()); + auto a_array = CreatePatternedMatrix(1, 1); + auto b_array = CreatePatternedMatrix(1, 1, /*offset=*/64.0); + auto a = builder.ConstantR2FromArray2D(*a_array); + auto b = builder.ConstantR2FromArray2D(*b_array); + auto concatenated = builder.ConcatInDim({a, b}, 0); + + Array2D expected({ + {0}, {64}, + }); + ComputeAndCompareR2(&builder, expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(ConcatTest, Concat_1x1_With_1x1_InDim1) { + ComputationBuilder builder(client_, TestName()); + auto a_array = CreatePatternedMatrix(1, 1); + auto b_array = CreatePatternedMatrix(1, 1, /*offset=*/64.0); + auto a = builder.ConstantR2FromArray2D(*a_array); + auto b = builder.ConstantR2FromArray2D(*b_array); + auto 
concatenated = builder.ConcatInDim({a, b}, 1); + + Array2D expected({ + {0, 64}, + }); + ComputeAndCompareR2(&builder, expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(ConcatTest, Concat2x0With2x5) { + ComputationBuilder builder(client_, TestName()); + auto b_array = CreatePatternedMatrix(2, 5, /*offset=*/64.0); + auto a = builder.ConstantR2FromArray2D(Array2D(2, 0)); + auto b = builder.ConstantR2FromArray2D(*b_array); + auto concatenated = builder.ConcatInDim({a, b}, 1); + + ComputeAndCompareR2(&builder, *b_array, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(ConcatTest, Concat2x3With2x5) { + ComputationBuilder builder(client_, TestName()); + auto a_array = CreatePatternedMatrix(2, 3); + auto b_array = CreatePatternedMatrix(2, 5, /*offset=*/64.0); + auto a = builder.ConstantR2FromArray2D(*a_array); + auto b = builder.ConstantR2FromArray2D(*b_array); + auto concatenated = builder.ConcatInDim({a, b}, 1); + + Array2D expected({ + {0, 1, 2, 64, 65, 66, 67, 68}, + {1000, 1001, 1002, 1064, 1065, 1066, 1067, 1068}, + }); + ComputeAndCompareR2(&builder, expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(ConcatTest, Concat3x2With0x2) { + ComputationBuilder builder(client_, TestName()); + auto a_array = CreatePatternedMatrix(3, 2); + auto a = builder.ConstantR2FromArray2D(*a_array); + auto b = builder.ConstantR2FromArray2D(Array2D(0, 2)); + auto concatenated = builder.ConcatInDim({a, b}, 0); + + ComputeAndCompareR2(&builder, *a_array, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(ConcatTest, Concat3x2With5x2) { + ComputationBuilder builder(client_, TestName()); + auto a_array = CreatePatternedMatrix(3, 2); + auto b_array = CreatePatternedMatrix(5, 2, /*offset=*/64.0); + auto a = builder.ConstantR2FromArray2D(*a_array); + auto b = builder.ConstantR2FromArray2D(*b_array); + auto concatenated = builder.ConcatInDim({a, b}, 0); + + Array2D expected({ + {0, 1}, + {1000, 1001}, + {2000, 2001}, + {64, 65}, + {1064, 1065}, + {2064, 2065}, + {3064, 3065}, + {4064, 4065}, + }); + ComputeAndCompareR2(&builder, expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(ConcatTest, Concat_R3_3x0x2_3x0x1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR3FromArray3D(Array3D(3, 0, 2)); + auto b = builder.ConstantR3FromArray3D(Array3D(3, 0, 1)); + auto concatenated = builder.ConcatInDim({a, b}, 2); + ComputeAndCompareR3(&builder, Array3D(3, 0, 3), {}, + ErrorSpec(0.0001)); +} + +XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1) { + ComputationBuilder builder(client_, TestName()); + Array3D a_array({ + // 3x1x2 + {{0, 1}}, + {{2, 3}}, + {{4, 5}}, + }); + Array3D b_array({ + // 3x1x1 + {{6}}, + {{7}}, + {{8}}, + }); + auto a = builder.ConstantR3FromArray3D(a_array); + auto b = builder.ConstantR3FromArray3D(b_array); + auto concatenated = builder.ConcatInDim({a, b}, 2); + + Array3D expected({ + {{0, 1, 6}}, {{2, 3, 7}}, {{4, 5, 8}}, + }); + ComputeAndCompareR3(&builder, expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(ConcatTest, Concat_R1_1x1_1x1_1x1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({42.0}); + auto b = builder.ConstantR1({64.0}); + auto c = builder.ConstantR1({256.0}); + auto concatenated = builder.ConcatInDim({a, b, c}, 0); + + std::vector expected = {42, 64, 256}; + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1_3x1x1) { + ComputationBuilder builder(client_, TestName()); + Array3D a_array({ + // 3x1x2 + {{0, 1}}, + {{4, 5}}, + {{8, 9}}, + }); + Array3D b_array({ + // 3x1x1 + {{2}}, + 
{{6}}, + {{10}}, + }); + Array3D c_array({ + // 3x1x1 + {{3}}, + {{7}}, + {{11}}, + }); + auto a = builder.ConstantR3FromArray3D(a_array); + auto b = builder.ConstantR3FromArray3D(b_array); + auto c = builder.ConstantR3FromArray3D(c_array); + auto concatenated = builder.ConcatInDim({a, b, c}, 2); + + Array3D expected({ + {{0, 1, 2, 3}}, {{4, 5, 6, 7}}, {{8, 9, 10, 11}}, + }); + ComputeAndCompareR3(&builder, expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(ConcatTest, DoubleConcatLeftAssociative) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({42.0}); + auto b = builder.ConstantR1({64.0}); + auto c = builder.ConstantR1({256.0}); + // concatenated = (a concat b) concat c + auto concatenated = + builder.ConcatInDim({builder.ConcatInDim({a, b}, 0), c}, 0); + + std::vector expected = {42, 64, 256}; + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(ConcatTest, DoubleConcatRightAssociative) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({42.0}); + auto b = builder.ConstantR1({64.0}); + auto c = builder.ConstantR1({256.0}); + // concatenated = a concat (b concat c) + auto concatenated = + builder.ConcatInDim({a, builder.ConcatInDim({b, c}, 0)}, 0); + + std::vector expected = {42, 64, 256}; + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(ConcatTest, Concat_1x1024_With_1x1024_InDim0) { + Array2D lhs(1, 1024); + Array2D rhs(1, 1024); + for (int i = 0; i < 1024; ++i) { + lhs(0, i) = i; + rhs(0, i) = i + 1024; + } + + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2FromArray2D(lhs); + auto b = builder.ConstantR2FromArray2D(rhs); + builder.ConcatInDim({a, b}, 0); + + Array2D expected(2, 1024); + for (int i = 0; i < 1024; ++i) { + expected(0, i) = i; + expected(1, i) = i + 1024; + } + ComputeAndCompareR2(&builder, expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(ConcatTest, Concat_1x1024_With_1x1024_InDim1) { + Array2D lhs(1, 1024); + Array2D rhs(1, 1024); + for (int i = 0; i < 1024; ++i) { + lhs(0, i) = i; + rhs(0, i) = i + 1024; + } + + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2FromArray2D(lhs); + auto b = builder.ConstantR2FromArray2D(rhs); + builder.ConcatInDim({a, b}, 1); + + Array2D expected(1, 2048); + for (int i = 0; i < 1024; ++i) { + expected(0, i) = i; + expected(0, i + 1024) = i + 1024; + } + ComputeAndCompareR2(&builder, expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(ConcatTest, Concat_64x64_With_64x2) { + Array2D lhs(64, 64); + Array2D rhs(64, 2); + for (int i0 = 0; i0 < 64; ++i0) { + for (int i1 = 0; i1 < 64; ++i1) { + lhs(i0, i1) = (i0 << 10) | i1; + } + for (int i1 = 0; i1 < 2; ++i1) { + rhs(i0, i1) = (i0 << 10) | (i1 + 64); + } + } + + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2FromArray2D(lhs); + auto b = builder.ConstantR2FromArray2D(rhs); + builder.ConcatInDim({a, b}, 1); + + Array2D expected(64, 66); + for (int i0 = 0; i0 < 64; ++i0) { + for (int i1 = 0; i1 < 66; ++i1) { + expected(i0, i1) = (i0 << 10) | i1; + } + } + ComputeAndCompareR2(&builder, expected, {}, ErrorSpec(0.0001)); +} + +// Show that we can't concatenate with an opaques. 
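+// An opaque value is a backend-specific handle rather than an array, so it has
+// no dimensions to concatenate along.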
+XLA_TEST_F(ConcatTest, CannotConcatOpaques) { + ComputationBuilder builder(client_, TestName()); + auto opaque_shape = ShapeUtil::MakeOpaqueShape(); + auto r1f32 = xla::ShapeUtil::MakeShape(xla::F32, {1}); + auto x = builder.Parameter(0, r1f32, "x"); + auto y = builder.Parameter(1, opaque_shape, "y"); + auto concatenated = builder.ConcatInDim({x, y}, 0); + StatusOr computation_status = builder.Build(); + ASSERT_FALSE(computation_status.ok()); + EXPECT_MATCH( + computation_status.status().ToString(), + testing::ContainsRegex( + "Expected non-opaque argument for operand of concatenation")); +} + +XLA_TEST_F(ConcatTest, ConcatSeveralBoxedPredicates) { + ComputationBuilder builder(client_, TestName()); + auto p0 = builder.ConstantR1({true}); + auto p1 = builder.ConstantR1({false}); + auto p2 = builder.ConstantR1({true}); + auto concatenated = builder.ConcatInDim({p0, p1, p2}, 0); + + bool expected[] = {true, false, true}; + ComputeAndCompareR1(&builder, expected, {}); +} + +XLA_TEST_F(ConcatTest, ConcatSeveralR1S32s) { + ComputationBuilder builder(client_, TestName()); + auto a0 = builder.ConstantR1({1}); + auto a1 = builder.ConstantR1({2, 3}); + auto a2 = builder.ConstantR1({4, 5, 6}); + auto a3 = builder.ConstantR1({7, 8, 9, 10}); + auto concatenated = builder.ConcatInDim({a0, a1, a2, a3}, 0); + + std::vector expected(10); + std::iota(expected.begin(), expected.end(), 1); + ComputeAndCompareR1(&builder, expected, {}); +} + +// Describes a binary rank-2 concatenation test. +struct R2BinarySpec { + int64 lhs_dim0; + int64 lhs_dim1; + int64 rhs_dim0; + int64 rhs_dim1; + int64 concat_dimension; +}; + +// TEST_P harness for binary rank-2 concatenation. +class ConcatR2BinaryTest : public ClientLibraryTestBase, + public ::testing::WithParamInterface { +}; + +TEST_P(ConcatR2BinaryTest, DoIt) { + const R2BinarySpec& spec = GetParam(); + Array2D lhs(spec.lhs_dim0, spec.lhs_dim1); + lhs.FillUnique(); + Array2D rhs(spec.rhs_dim0, spec.rhs_dim1); + rhs.FillUnique(1000); + + ComputationBuilder builder(client_, TestName()); + auto a0 = builder.ConstantR2FromArray2D(lhs); + auto a1 = builder.ConstantR2FromArray2D(rhs); + builder.ConcatInDim({a0, a1}, spec.concat_dimension); + + std::unique_ptr> expected = + ReferenceUtil::Concat2D(lhs, rhs, spec.concat_dimension); + ComputeAndCompareR2(&builder, *expected, {}); +} + +// Regression test for b/31944287. x*y is used (at the same index) by all +// operands of the concat. We should emit x*y in three incoming basic blocks of +// the concat because these basic blocks are not control-equivalent. 
+// +// x*y +// / | \ +// add1 add2 add3 +// \ | / +// concat +XLA_TEST_F(ConcatTest, ConcatOperandsOfSameOperand) { + auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {}); + auto x_literal = LiteralUtil::CreateR0(2.f); + auto y_literal = LiteralUtil::CreateR0(3.f); + auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); + auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); + + ComputationBuilder builder(client_, TestName()); + auto x = builder.Parameter(0, f32_scalar, "x"); + auto y = builder.Parameter(1, f32_scalar, "y"); + auto mul = builder.Mul(x, y); + auto add1 = builder.Add(mul, builder.ConstantR1({1.f, 2.f})); + auto add2 = builder.Add(mul, builder.ConstantR1({3.f, 4.f})); + auto add3 = builder.Add(mul, builder.ConstantR1({5.f, 6.f})); + builder.ConcatInDim({add1, add2, add3}, /*dimension=*/0); + + ComputeAndCompareR1(&builder, {7., 8., 9., 10., 11., 12.}, + {x_data.get(), y_data.get()}, ErrorSpec(1e-4)); +} + +INSTANTIATE_TEST_CASE_P(ConcatR2BinaryTestInstantiation, ConcatR2BinaryTest, + ::testing::Values(R2BinarySpec{1, 1, 1, 1, 0}, + R2BinarySpec{1, 1, 1, 1, 1}, + R2BinarySpec{4, 3, 4, 3, 0}, + R2BinarySpec{4, 3, 4, 3, 1}, + R2BinarySpec{7, 128, 1, 128, 0}, + R2BinarySpec{8, 127, 8, 1, 1})); + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/constants_test.cc b/tensorflow/compiler/xla/tests/constants_test.cc new file mode 100644 index 0000000000..58d52ac116 --- /dev/null +++ b/tensorflow/compiler/xla/tests/constants_test.cc @@ -0,0 +1,193 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Tests that constants in program memory round trip as expected. 
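+// Each test below builds a constant of some rank from a host-side array,
+// executes it through the client, and compares the transferred result against
+// the original host values.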
+ +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/array3d.h" +#include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class ConstantsTest : public ClientLibraryTestBase { + protected: + const ErrorSpec error_spec_{1e-3, 1e-5}; +}; + +TEST_F(ConstantsTest, ZeroCellF32) { + ComputationBuilder builder(client_, TestName()); + builder.ConstantR1({}); + + ComputeAndCompareR1(&builder, {}, {}, error_spec_); +} + +TEST_F(ConstantsTest, OneCellF32) { + std::vector constant = {2.0}; + + ComputationBuilder builder(client_, TestName()); + builder.ConstantR1(constant); + + ComputeAndCompareR1(&builder, constant, {}, error_spec_); +} + +TEST_F(ConstantsTest, OneCellS32) { + std::vector constant = {2}; + + ComputationBuilder builder(client_, TestName()); + builder.ConstantR1(constant); + + ComputeAndCompareR1(&builder, constant, {}); +} + +TEST_F(ConstantsTest, OneCellU32) { + std::vector constant = {2}; + + ComputationBuilder builder(client_, TestName()); + builder.ConstantR1(constant); + + ComputeAndCompareR1(&builder, constant, {}); +} + +TEST_F(ConstantsTest, EightCells) { + std::vector constant = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}; + + ComputationBuilder builder(client_, TestName()); + builder.ConstantR1(constant); + + ComputeAndCompareR1(&builder, constant, {}, error_spec_); +} + +TEST_F(ConstantsTest, SixteenCells) { + std::vector constant = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, + 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0}; + + ComputationBuilder builder(client_, TestName()); + builder.ConstantR1(constant); + + ComputeAndCompareR1(&builder, constant, {}, error_spec_); +} + +TEST_F(ConstantsTest, Empty_0x2) { + ComputationBuilder builder(client_, TestName()); + builder.ConstantR2FromArray2D(Array2D(0, 2)); + + ComputeAndCompareR2(&builder, Array2D(0, 2), {}, error_spec_); +} + +TEST_F(ConstantsTest, Small_2x2) { + std::unique_ptr> constant = + MakeLinspaceArray2D(100.0, 200.0, 2, 2); + + ComputationBuilder builder(client_, TestName()); + builder.ConstantR2FromArray2D(*constant); + + ComputeAndCompareR2(&builder, *constant, {}, error_spec_); +} + +TEST_F(ConstantsTest, Empty_3x0x2) { + ComputationBuilder builder(client_, TestName()); + auto constant = builder.ConstantLiteral( + *LiteralUtil::CreateR3FromArray3D(Array3D(3, 0, 2))); + + ComputeAndCompareR3(&builder, Array3D(3, 0, 2), {}); +} + +TEST_F(ConstantsTest, Small_2x2x2) { + ComputationBuilder builder(client_, TestName()); + Array3D array3d({ + // x0 x1 + {{1.f, 2.f}, // y0 + {3.f, 4.f}}, // y1 + + {{5.f, 6.f}, // y0 + {7.f, 8.f}}, // y1 + }); + auto constant = builder.ConstantLiteral( + *LiteralUtil::CreateR3FromArray3D(array3d)); + + ComputeAndCompareR3(&builder, array3d, {}); +} + +TEST_F(ConstantsTest, Small_3x2x1x1) { + Array4D input_array(3, 2, 1, 1); + Array2D pz({ + // z0 z1 + {-1.0f, 4.1f}, // p0 + {2.0f, 4.1f}, // p1 + {5.0f, 4.4f}, // p2 + }); + input_array.FillWithPZ(pz); + Literal input_literal = 
*LiteralUtil::CreateR4FromArray4D(input_array); + + { + ComputationBuilder builder(client_, TestName()); + builder.ConstantLiteral(input_literal); + ComputeAndCompareR4(&builder, input_array, {}, error_spec_); + } + + { + ComputationBuilder builder(client_, TestName()); + builder.ConstantR4FromArray4D(input_array); + ComputeAndCompareR4(&builder, input_array, {}, error_spec_); + } +} + +// TODO(b/29263943): Support tuple constants. +TEST_F(ConstantsTest, DISABLED_TupleConstant) { + ComputationBuilder builder(client_, TestName()); + builder.ConstantLiteral(*LiteralUtil::MakeTuple( + {LiteralUtil::CreateR2({{1.0}, {2.0}}).get(), + LiteralUtil::CreateR1({2.0, 42}).get()})); + + std::unique_ptr result = ExecuteAndTransferOrDie(&builder, {}); + + LiteralTestUtil::ExpectR2Near({{1.0}, {2.0}}, + result->tuple_literals(0), error_spec_); + LiteralTestUtil::ExpectR1Near({2.0, 42.0}, result->tuple_literals(1), + error_spec_); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc new file mode 100644 index 0000000000..9f8c3a9aeb --- /dev/null +++ b/tensorflow/compiler/xla/tests/convert_test.cc @@ -0,0 +1,210 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class ConvertTest : public ClientLibraryTestBase { + public: + explicit ConvertTest(perftools::gputools::Platform* platform = nullptr) + : ClientLibraryTestBase(platform, + /*disabled_pass_names=*/{"algsimp", "inline"}) {} +}; + +TEST_F(ConvertTest, ConvertR1S32ToR1S32) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({42, 64}); + builder.ConvertElementType(a, S32); + + std::vector expected = {42, 64}; + ComputeAndCompareR1(&builder, expected, {}); +} + +TEST_F(ConvertTest, ConvertR1F32ToR1F32) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({42.0f, 64.0f}); + builder.ConvertElementType(a, F32); + + std::vector expected = {42.0f, 64.0f}; + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); +} + +TEST_F(ConvertTest, ConvertR1S32ToR1F32) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({42, 64}); + builder.ConvertElementType(a, F32); + + std::vector expected = {42.0f, 64.0f}; + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(ConvertTest, ConvertR1S0S32ToR1S0F32) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + builder.ConvertElementType(a, F32); + + std::vector expected = {}; + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); +} + +TEST_F(ConvertTest, ConvertR1F32ToR1S32) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({42.6, 64.4}); + builder.ConvertElementType(a, S32); + + std::vector expected = {42, 64}; + ComputeAndCompareR1(&builder, expected, {}); +} + +XLA_TEST_F(ConvertTest, ConvertR1S64ToR1F32) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({32, 64}); + builder.ConvertElementType(a, F32); + + std::vector expected = {32.0, 64.0}; + ComputeAndCompareR1(&builder, expected, {}); +} + +XLA_TEST_F(ConvertTest, ConvertR1U8ToR1F32) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({32, 64}); + builder.ConvertElementType(a, F32); + + std::vector expected = {32.0, 64.0}; + ComputeAndCompareR1(&builder, expected, {}); +} + +XLA_TEST_F(ConvertTest, ConvertR1U8ToR1S32) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({32, 64}); + builder.ConvertElementType(a, S32); + + std::vector expected = {32, 64}; + ComputeAndCompareR1(&builder, expected, {}); +} + +XLA_TEST_F(ConvertTest, ConvertR1U8ToR1U32) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({32, 64}); + builder.ConvertElementType(a, U32); + + std::vector expected = {32, 64}; + ComputeAndCompareR1(&builder, expected, {}); +} + +XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F64) { + ComputationBuilder 
builder(client_, TestName()); + auto a = builder.ConstantR1({32.0f, 64.0f}); + builder.ConvertElementType(a, F64); + + std::vector expected = {32.0, 64.0}; + ComputeAndCompareR1(&builder, expected, {}); +} + +XLA_TEST_F(ConvertTest, ConvertR1F64ToR1F32) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({32.0, 64.0}); + builder.ConvertElementType(a, F32); + + std::vector expected = {32.0f, 64.0f}; + ComputeAndCompareR1(&builder, expected, {}); +} + +TEST_F(ConvertTest, ConvertS32Extremes) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1( + {std::numeric_limits::min(), std::numeric_limits::max()}); + builder.ConvertElementType(a, F32); + + std::vector expected = { + static_cast(std::numeric_limits::min()), + static_cast(std::numeric_limits::max())}; + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); +} + +TEST_F(ConvertTest, ConvertMapToS32) { + ComputationBuilder builder(client_, TestName()); + auto b = builder.CreateSubBuilder("convert"); + auto param = b->Parameter(0, ShapeUtil::MakeShape(F32, {}), "in"); + b->ConvertElementType(param, S32); + auto a = builder.ConstantR1({42.0f, 64.0f}); + builder.Map({a}, b->BuildAndNoteError()); + + std::vector expected = {42, 64}; + ComputeAndCompareR1(&builder, expected, {}); +} + +TEST_F(ConvertTest, ConvertMapToF32) { + ComputationBuilder builder(client_, TestName()); + auto b = builder.CreateSubBuilder("convert"); + auto param = b->Parameter(0, ShapeUtil::MakeShape(S32, {}), "in"); + b->ConvertElementType(param, F32); + auto a = builder.ConstantR1({42, 64}); + builder.Map({a}, b->BuildAndNoteError()); + + std::vector expected = {42.0f, 64.0f}; + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); +} + +// Regression test for b/31758660. When ReshapeMover transforms +// input -> reshape -> convert +// to +// input -> convert -> reshape +// the new convert should have the same element type as the old convert. +TEST_F(ConvertTest, ConvertReshape) { + ComputationBuilder builder(client_, TestName()); + auto input = builder.ConstantR1({42}); + auto reshape = builder.Reshape(input, /*dimensions=*/{0}, /*new_sizes=*/{}); + builder.ConvertElementType(reshape, F32); + + ComputeAndCompareR0(&builder, 42.0f, {}, ErrorSpec(0.0001)); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc new file mode 100644 index 0000000000..9f38dc4b36 --- /dev/null +++ b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc @@ -0,0 +1,117 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/padding.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/reference_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class ConvolutionDimensionNumbersTest : public ClientLibraryTestBase {}; + +// Tests the convolution operation with invalid input dimension numbers. +TEST_F(ConvolutionDimensionNumbersTest, InvalidInputDimensionNumbers) { + auto dimension_numbers_status = + ComputationBuilder::CreateConvDimensionNumbers(0, 2, 2, 3, 0, 1, 2, 3); + ASSERT_FALSE(dimension_numbers_status.ok()); + ASSERT_MATCH(dimension_numbers_status.status().error_message(), + testing::ContainsRegex("input are not unique")); +} + +// Tests the convolution operation with invalid weight dimension numbers. +TEST_F(ConvolutionDimensionNumbersTest, InvalidWeightDimensionNumbers) { + auto dimension_numbers_status = + ComputationBuilder::CreateConvDimensionNumbers(0, 1, 2, 3, 2, 3, 2, 3); + ASSERT_FALSE(dimension_numbers_status.ok()); + ASSERT_MATCH(dimension_numbers_status.status().error_message(), + testing::ContainsRegex("weight are not unique")); +} + +XLA_TEST_F(ConvolutionDimensionNumbersTest, + TwoConvsWithDifferentDimensionNumbers) { + auto input_array = MakeUnique>(2, 3, 5, 5); + input_array->FillWithMultiples(0.1); + auto weight_array = MakeUnique>(4, 3, 1, 1); + weight_array->FillWithMultiples(0.2); + auto weight_data = + client_ + ->TransferToServer(*LiteralUtil::CreateR4FromArray4D(*weight_array)) + .ConsumeValueOrDie(); + + ComputationBuilder builder(client_, TestName()); + auto input = builder.ConstantR4FromArray4D(*input_array); + auto weight = + builder.Parameter(0, ShapeUtil::MakeShape(F32, {4, 3, 1, 1}), "weight"); + auto conv1 = builder.Conv(input, weight, {1, 1}, Padding::kValid); + + ConvolutionDimensionNumbers dim_nums = + ComputationBuilder::CreateDefaultConvDimensionNumbers(); + // Swap batch_dimension and feature_dimension. + int64 tmp = dim_nums.batch_dimension(); + dim_nums.set_batch_dimension(dim_nums.feature_dimension()); + dim_nums.set_feature_dimension(tmp); + // Swap kernel_input_feature_dimension and kernel_output_feature_dimension. 
+ tmp = dim_nums.kernel_input_feature_dimension(); + dim_nums.set_kernel_input_feature_dimension( + dim_nums.kernel_output_feature_dimension()); + dim_nums.set_kernel_output_feature_dimension(tmp); + builder.ConvWithGeneralDimensions(input, conv1, {1, 1}, Padding::kValid, + dim_nums); + + auto expected_conv1 = ReferenceUtil::ConvArray4D(*input_array, *weight_array, + {1, 1}, Padding::kValid); + auto expected_conv2 = ReferenceUtil::ConvArray4DGeneralDimensions( + *input_array, *expected_conv1, {1, 1}, Padding::kValid, dim_nums); + + ComputeAndCompareR4(&builder, *expected_conv2, {weight_data.get()}, + ErrorSpec(0.001, 0.01)); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc new file mode 100644 index 0000000000..ffbda89b94 --- /dev/null +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -0,0 +1,361 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Tests of convolution with trivial kernels and no special variations (like +// strides and padding). 
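As a reader's aid (an editorial sketch with a hypothetical helper name, not code from this patch), the stride-1 output sizes these convolution tests assert follow directly from the two implicit padding modes:

#include <cstdint>

// Stride-1 output extent per spatial dimension for the padding modes used
// in the tests below. Illustrative only; not part of convolution_test.cc.
int64_t ConvOutputSize(int64_t input, int64_t kernel, bool same_padding) {
  // Padding::kValid adds no zeros, so only fully overlapping windows count:
  // a 4x4 input with a 2x2 kernel yields a 3x3 output.
  // Padding::kSame adds kernel - 1 zeros split across the two ends, so a
  // 4x4 input keeps producing a 4x4 output.
  return same_padding ? input : input - kernel + 1;
}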
+ +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/padding.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/reference_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class ConvolutionTest : public ClientLibraryTestBase { + protected: +#if XLA_TEST_BACKEND_GPU + // XLA:GPU sometimes uses FFT convolution which isn't as precise as spatial + // convolution. So relax the absolute error threshold. + ErrorSpec error_spec_ = ErrorSpec(1e-3); +#else + ErrorSpec error_spec_ = ErrorSpec(1e-4); +#endif +}; + +XLA_TEST_F(ConvolutionTest, ForwardPassConvolution_3x3x256_256_OutputZ_Iota) { + const int kInputActivationSizeY = 3; + const int kInputActivationSizeX = 3; + const int kInputActivationSizeZ = 256; + const int kKernelSizeX = 2; + const int kKernelSizeY = 2; + const int kOutputActivationSizeZ = 256; + const int kMiniBatchSize = 4; + auto alhs = + MakeUnique>(kMiniBatchSize, kInputActivationSizeZ, + kInputActivationSizeY, kInputActivationSizeX); + alhs->FillWithMultiples(1.0f); + ASSERT_EQ(3, alhs->width()); + ASSERT_EQ(3, alhs->height()); + + auto arhs = + MakeUnique>(kOutputActivationSizeZ, kInputActivationSizeZ, + kKernelSizeY, kKernelSizeX); + Array2D rhs_raster({ + {1.0f, 0.0f}, // row 0 + {0.0f, 0.0f}, // row 1 + }); + arhs->FillWithYX(rhs_raster); + ASSERT_EQ(2, arhs->width()); + ASSERT_EQ(2, arhs->height()); + + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR4FromArray4D(*alhs); + auto rhs = builder.ConstantR4FromArray4D(*arhs); + builder.Conv(lhs, rhs, {1, 1}, Padding::kValid); + + std::unique_ptr> aexpected = + ReferenceUtil::ConvArray4D(*alhs, *arhs, {1, 1}, Padding::kValid); + + ComputeAndCompareR4(&builder, *aexpected, {}, error_spec_); +} + +TEST_F(ConvolutionTest, Convolve_1x1x1x2_1x1x1x2_Valid) { + ComputationBuilder builder(client_, TestName()); + { + Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); + Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); + auto input = builder.Parameter(0, input_shape, "input"); + auto filter = builder.Parameter(1, filter_shape, "filter"); + builder.Conv(input, filter, {1, 1}, Padding::kValid); + } + + Array4D input(1, 1, 1, 2); + input.FillWithYX(Array2D({ + {1, 2}, + })); + Array4D filter(1, 1, 1, 2); + filter.FillWithYX(Array2D({ + {5, 6}, + })); + + std::unique_ptr> aexpected = + ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kValid); + + auto input_literal = + client_->TransferToServer(*LiteralUtil::CreateR4FromArray4D(input)) + .ConsumeValueOrDie(); + auto filter_literal = + 
client_->TransferToServer(*LiteralUtil::CreateR4FromArray4D(filter)) + .ConsumeValueOrDie(); + + ComputeAndCompareR4(&builder, *aexpected, + {input_literal.get(), filter_literal.get()}, + error_spec_); +} + +// Tests valid padding for 2D convolution in raster space. +TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x2x2_Valid) { + ComputationBuilder builder(client_, TestName()); + { + Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); + Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 2}); + auto input = builder.Parameter(0, input_shape, "input"); + auto filter = builder.Parameter(1, filter_shape, "filter"); + builder.Conv(input, filter, {1, 1}, Padding::kValid); + } + + Array4D input(1, 1, 4, 4); + // clang-format off + input.FillWithYX(Array2D({ + {1, 2, 3, 4 }, + {5, 6, 7, 8 }, + {9, 10, 11, 12}, + {13, 14, 15, 16}, + })); + // clang-format on + Array4D filter(1, 1, 2, 2); + // clang-format off + filter.FillWithYX(Array2D({ + {5, 6}, + {7, 8}, + })); + // clang-format on + + std::unique_ptr> aexpected = + ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kValid); + + auto input_literal = + client_->TransferToServer(*LiteralUtil::CreateR4FromArray4D(input)) + .ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(*LiteralUtil::CreateR4FromArray4D(filter)) + .ConsumeValueOrDie(); + + ComputeAndCompareR4(&builder, *aexpected, + {input_literal.get(), filter_literal.get()}, + error_spec_); +} + +// Tests same padding for 2D convolution in raster space. +TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x2x2_Same) { + ComputationBuilder builder(client_, TestName()); + { + Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); + Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 2}); + auto input = builder.Parameter(0, input_shape, "input"); + auto filter = builder.Parameter(1, filter_shape, "filter"); + builder.Conv(input, filter, {1, 1}, Padding::kSame); + } + + Array4D input(1, 1, 4, 4); + // clang-format off + input.FillWithYX(Array2D({ + {1, 2, 3, 4 }, + {5, 6, 7, 8 }, + {9, 10, 11, 12}, + {13, 14, 15, 16}, + })); + // clang-format on + Array4D filter(1, 1, 2, 2); + // clang-format off + filter.FillWithYX(Array2D({ + {5, 6}, + {7, 8}, + })); + // clang-format on + + std::unique_ptr> aexpected = + ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kSame); + + auto input_literal = + client_->TransferToServer(*LiteralUtil::CreateR4FromArray4D(input)) + .ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(*LiteralUtil::CreateR4FromArray4D(filter)) + .ConsumeValueOrDie(); + + ComputeAndCompareR4(&builder, *aexpected, + {input_literal.get(), filter_literal.get()}, + error_spec_); +} + +// Tests same padding for 2D convolution in raster space with an odd sized +// kernel. 
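// (Editorial arithmetic, not part of the original patch.) With this odd
// 3x3 kernel at stride 1, Padding::kSame needs kernel - 1 = 2 extra rows
// and columns in total, i.e. one zero row/column on each side of the 4x4
// input, so the output keeps the 4x4 shape that ReferenceUtil::ConvArray4D
// computes below.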
+TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x3x3_Same) { + ComputationBuilder builder(client_, TestName()); + { + Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); + Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 3, 3}); + auto input = builder.Parameter(0, input_shape, "input"); + auto filter = builder.Parameter(1, filter_shape, "filter"); + builder.Conv(input, filter, {1, 1}, Padding::kSame); + } + + Array4D input(1, 1, 4, 4); + // clang-format off + input.FillWithYX(Array2D({ + {1, 2, 3, 4 }, + {5, 6, 7, 8 }, + {9, 10, 11, 12}, + {13, 14, 15, 16}, + })); + // clang-format on + Array4D filter(1, 1, 3, 3); + // clang-format off + filter.FillWithYX(Array2D({ + { 5, 6, 7}, + { 8, 9, 10}, + {11, 12, 13}, + })); + // clang-format on + + std::unique_ptr> aexpected = + ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kSame); + + auto input_literal = + client_->TransferToServer(*LiteralUtil::CreateR4FromArray4D(input)) + .ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(*LiteralUtil::CreateR4FromArray4D(filter)) + .ConsumeValueOrDie(); + + ComputeAndCompareR4(&builder, *aexpected, + {input_literal.get(), filter_literal.get()}, + error_spec_); +} + +// TODO(b/32873825): implement 1D convolution on GPU. +XLA_TEST_F(ConvolutionTest, DISABLED_ON_GPU(Convolve1D_1x2x5_1x2x2_Valid)) { + ComputationBuilder builder(client_, TestName()); + { + Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5}); + Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); + auto input = builder.Parameter(0, input_shape, "input"); + auto filter = builder.Parameter(1, filter_shape, "filter"); + builder.Conv(input, filter, {1}, Padding::kValid); + } + + Array3D input({{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}}); + Array3D filter({{{10, 20}, {30, 40}}}); + + Array3D expected({{{510, 610, 710, 810}}}); + + auto input_literal = + client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input)) + .ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter)) + .ConsumeValueOrDie(); + + ComputeAndCompareR3(&builder, expected, + {input_literal.get(), filter_literal.get()}, + error_spec_); +} + +// TODO(b/32873825): implement 3D convolution on GPU. +XLA_TEST_F(ConvolutionTest, + DISABLED_ON_GPU(Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid)) { + ComputationBuilder builder(client_, TestName()); + std::vector input_dims = {1, 4, 2, 3, 3}; + std::vector filter_dims = {2, 2, 2, 3, 3}; + Shape input_shape = ShapeUtil::MakeShape(F32, input_dims); + Shape filter_shape = ShapeUtil::MakeShape(F32, filter_dims); + { + auto input = builder.Parameter(0, input_shape, "input"); + auto filter = builder.Parameter(1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 3D convolution. 
+ ConvolutionDimensionNumbers dnums; + dnums.set_batch_dimension(0); + dnums.add_spatial_dimensions(1); + dnums.add_spatial_dimensions(2); + dnums.add_spatial_dimensions(3); + dnums.set_feature_dimension(4); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.add_kernel_spatial_dimensions(2); + dnums.set_kernel_input_feature_dimension(3); + dnums.set_kernel_output_feature_dimension(4); + + builder.ConvWithGeneralDimensions(input, filter, {1, 1, 1}, Padding::kValid, + dnums); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); + std::iota(input_elems.begin(), input_elems.end(), 1.0f); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r5 = + LiteralUtil::Reshape(*input_r1, input_dims).ConsumeValueOrDie(); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); + std::iota(filter_elems.begin(), filter_elems.end(), 1.0f); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r5 = + LiteralUtil::Reshape(*filter_r1, filter_dims).ConsumeValueOrDie(); + + auto expected_r1 = LiteralUtil::CreateR1( + {19554, 19962, 20370, 22110, 22590, 23070, 34890, 35730, 36570, 37446, + 38358, 39270, 50226, 51498, 52770, 52782, 54126, 55470}); + auto expected_r5 = + LiteralUtil::Reshape(*expected_r1, {1, 3, 1, 2, 3}).ConsumeValueOrDie(); + + auto input_literal = client_->TransferToServer(*input_r5).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(*filter_r5).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, *expected_r5, + {input_literal.get(), filter_literal.get()}, + error_spec_); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/convolution_variants_test.cc b/tensorflow/compiler/xla/tests/convolution_variants_test.cc new file mode 100644 index 0000000000..b599f9b95b --- /dev/null +++ b/tensorflow/compiler/xla/tests/convolution_variants_test.cc @@ -0,0 +1,1294 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Tests of convolution variants -- kernel sizes, padding, and strides -- +// in small sized data. 
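A note for readers (editorial, not part of the patch): many expected values in the variants file that follows are built from place-value filters, so each decimal digit of an output identifies the input element that a filter tap touched. A minimal sketch of that convention, with hypothetical names:

#include <cstddef>
#include <vector>

// One valid-window dot product, as used to derive expectations such as
// Filter1x3stride1x2in1x4: {1, 2, 3} against {100, 10, 1} gives 123, so a
// wrong stride or padding shows up directly in the digits of the output.
int PlaceValueDot(const std::vector<int>& window,
                  const std::vector<int>& taps) {
  int acc = 0;
  for (size_t i = 0; i < taps.size(); ++i) {
    acc += window[i] * taps[i];  // e.g. 100*1 + 10*2 + 1*3
  }
  return acc;
}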
+ +#include +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/padding.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/reference_util.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class ConvolutionVariantsTest : public ClientLibraryTestBase { + protected: +#if XLA_TEST_BACKEND_GPU + // XLA:GPU sometimes uses FFT convolution which isn't as precise as spatial + // convolution. So relax the absolute error threshold. + ErrorSpec error_spec_ = ErrorSpec(1e-1, 1e-5); +#else + ErrorSpec error_spec_ = ErrorSpec(1e-4, 1e-2); +#endif +}; + +TEST_F(ConvolutionVariantsTest, Minimal) { + ComputationBuilder builder(client_, TestName()); + + const Array4D input_array(1, 1, 1, 1, {2}); + auto input = builder.ConstantR4FromArray4D(input_array); + + const Array4D filter_array(1, 1, 1, 1, {3}); + auto filter = builder.ConstantR4FromArray4D(filter_array); + + builder.Conv(input, filter, {1, 1}, Padding::kValid); + + const Array4D expected(1, 1, 1, 1, {6}); + ComputeAndCompareR4(&builder, expected, {}, error_spec_); +} + +TEST_F(ConvolutionVariantsTest, MinimalWithBatch) { + ComputationBuilder builder(client_, TestName()); + + const Array4D input_array(5, 1, 1, 1, {1, 2, 3, 4, 5}); + auto input = builder.ConstantR4FromArray4D(input_array); + + const Array4D filter_array(1, 1, 1, 1, {2}); + auto filter = builder.ConstantR4FromArray4D(filter_array); + + builder.Conv(input, filter, {1, 1}, Padding::kValid); + + const Array4D expected(5, 1, 1, 1, {2, 4, 6, 8, 10}); + ComputeAndCompareR4(&builder, expected, {}, error_spec_); +} + +TEST_F(ConvolutionVariantsTest, Flat1x1) { + ComputationBuilder builder(client_, TestName()); + + Array4D input_array(2, 1, 3, 4); + input_array.FillWithMultiples(1); + auto input = builder.ConstantR4FromArray4D(input_array); + + const Array4D filter_array(1, 1, 1, 1, {2.3}); + auto filter = builder.ConstantR4FromArray4D(filter_array); + + builder.Conv(input, filter, {1, 1}, Padding::kValid); + + Array4D expected(2, 1, 3, 4); + expected.FillWithMultiples(2.3); + ComputeAndCompareR4(&builder, expected, {}, error_spec_); +} + +TEST_F(ConvolutionVariantsTest, Deep1x1) { + ComputationBuilder builder(client_, TestName()); + + Array4D input_array(1, 2, 1, 1, {10, 1}); + auto input = builder.ConstantR4FromArray4D(input_array); + + const Array4D filter_array(3, 2, 1, 1, {1, 2, 3, 4, 5, 6}); + auto filter = builder.ConstantR4FromArray4D(filter_array); + + builder.Conv(input, filter, {1, 1}, Padding::kValid); + + Array4D expected(1, 3, 1, 1, {12, 34, 56}); + ComputeAndCompareR4(&builder, expected, {}, error_spec_); +} + +TEST_F(ConvolutionVariantsTest, Filter1x2in1x2) { + ComputationBuilder builder(client_, TestName()); + + Array4D input_array(1, 1, 1, 2, {1, 2}); + auto input = builder.ConstantR4FromArray4D(input_array); + + const Array4D filter_array(1, 1, 1, 2, {10, 1}); + auto filter = builder.ConstantR4FromArray4D(filter_array); + + builder.Conv(input, filter, {1, 1}, Padding::kValid); + + Array4D expected(1, 1, 1, 
1, {12}); + ComputeAndCompareR4(&builder, expected, {}, error_spec_); +} + +TEST_F(ConvolutionVariantsTest, Filter1x2in1x3) { + ComputationBuilder builder(client_, TestName()); + + Array4D input_array(1, 1, 1, 3, {1, 2, 3}); + auto input = builder.ConstantR4FromArray4D(input_array); + + const Array4D filter_array(1, 1, 1, 2, {10, 1}); + auto filter = builder.ConstantR4FromArray4D(filter_array); + + builder.Conv(input, filter, {1, 1}, Padding::kValid); + + Array4D expected(1, 1, 1, 2, {12, 23}); + ComputeAndCompareR4(&builder, expected, {}, error_spec_); +} + +TEST_F(ConvolutionVariantsTest, Filter1x2in2x2) { + ComputationBuilder builder(client_, TestName()); + + Array4D input_array(1, 1, 2, 2, {1, 2, 3, 4}); + auto input = builder.ConstantR4FromArray4D(input_array); + + const Array4D filter_array(1, 1, 1, 2, {10, 1}); + auto filter = builder.ConstantR4FromArray4D(filter_array); + + builder.Conv(input, filter, {1, 1}, Padding::kValid); + + Array4D expected(1, 1, 2, 1, {12, 34}); + ComputeAndCompareR4(&builder, expected, {}, error_spec_); +} + +TEST_F(ConvolutionVariantsTest, Filter2x1in2x2) { + ComputationBuilder builder(client_, TestName()); + + Array4D input_array(1, 1, 2, 2, {1, 2, 3, 4}); + auto input = builder.ConstantR4FromArray4D(input_array); + + const Array4D filter_array(1, 1, 2, 1, {10, 1}); + auto filter = builder.ConstantR4FromArray4D(filter_array); + + builder.Conv(input, filter, {1, 1}, Padding::kValid); + + Array4D expected(1, 1, 1, 2, {13, 24}); + ComputeAndCompareR4(&builder, expected, {}, error_spec_); +} + +TEST_F(ConvolutionVariantsTest, Filter2x2in2x2) { + ComputationBuilder builder(client_, TestName()); + + Array4D input_array(1, 1, 2, 2, {1, 2, 3, 4}); + auto input = builder.ConstantR4FromArray4D(input_array); + + const Array4D filter_array(1, 1, 2, 2, {1000, 100, 10, 1}); + auto filter = builder.ConstantR4FromArray4D(filter_array); + + builder.Conv(input, filter, {1, 1}, Padding::kValid); + + Array4D expected(1, 1, 1, 1, {1234}); + ComputeAndCompareR4(&builder, expected, {}, error_spec_); +} + +TEST_F(ConvolutionVariantsTest, Filter1x2in2x3WithDepthAndBatch) { + ComputationBuilder builder(client_, TestName()); + + Array4D input_array( + 2, 2, 2, 3, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0, // plane 0 + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 0, 0}); // plane 1 + auto input = builder.ConstantR4FromArray4D(input_array); + + const Array4D filter_array( + 2, 2, 1, 2, {1000, 100, 10, 1, 0.1, 0.01, 0.001, 0.0001}); + auto filter = builder.ConstantR4FromArray4D(filter_array); + + builder.Conv(input, filter, {1, 1}, Padding::kValid); + + Array4D expected( + 2, 2, 2, 2, + {167, 1278, 3490, 4500, 0.0167, 0.1278, 0.3490, 0.4500, // plane 0 + 334, 2556, 6980, 9000, 0.0334, 0.2556, 0.6980, 0.9000}); // plane 1 + ComputeAndCompareR4(&builder, expected, {}, error_spec_); +} + +TEST_F(ConvolutionVariantsTest, Filter1x1stride1x2in1x4) { + ComputationBuilder builder(client_, TestName()); + + Array4D input_array(1, 1, 1, 4, {1, 2, 3, 4}); + auto input = builder.ConstantR4FromArray4D(input_array); + + const Array4D filter_array(1, 1, 1, 1, {10}); + auto filter = builder.ConstantR4FromArray4D(filter_array); + + builder.Conv(input, filter, {1, 2}, Padding::kValid); + + Array4D expected(1, 1, 1, 2, {10, 30}); + ComputeAndCompareR4(&builder, expected, {}, error_spec_); +} + +TEST_F(ConvolutionVariantsTest, Filter1x1stride1x2in1x5) { + ComputationBuilder builder(client_, TestName()); + + Array4D input_array(1, 1, 1, 5, {1, 2, 3, 4, 5}); + auto input = builder.ConstantR4FromArray4D(input_array); + + 
const Array4D filter_array(1, 1, 1, 1, {10}); + auto filter = builder.ConstantR4FromArray4D(filter_array); + + builder.Conv(input, filter, {1, 2}, Padding::kValid); + + Array4D expected(1, 1, 1, 3, {10, 30, 50}); + ComputeAndCompareR4(&builder, expected, {}, error_spec_); +} + +TEST_F(ConvolutionVariantsTest, Filter1x3stride1x2in1x4) { + ComputationBuilder builder(client_, TestName()); + + Array4D input_array(1, 1, 1, 4, {1, 2, 3, 4}); + auto input = builder.ConstantR4FromArray4D(input_array); + + const Array4D filter_array(1, 1, 1, 3, {100, 10, 1}); + auto filter = builder.ConstantR4FromArray4D(filter_array); + + builder.Conv(input, filter, {1, 2}, Padding::kValid); + + Array4D expected(1, 1, 1, 1, {123}); + ComputeAndCompareR4(&builder, expected, {}, error_spec_); +} + +TEST_F(ConvolutionVariantsTest, Filter1x3stride1x2in1x5) { + ComputationBuilder builder(client_, TestName()); + + Array4D input_array(1, 1, 1, 5, {1, 2, 3, 4, 5}); + auto input = builder.ConstantR4FromArray4D(input_array); + + const Array4D filter_array(1, 1, 1, 3, {100, 10, 1}); + auto filter = builder.ConstantR4FromArray4D(filter_array); + + builder.Conv(input, filter, {1, 2}, Padding::kValid); + + Array4D expected(1, 1, 1, 2, {123, 345}); + ComputeAndCompareR4(&builder, expected, {}, error_spec_); +} + +TEST_F(ConvolutionVariantsTest, Filter1x1stride2x2in3x3) { + ComputationBuilder builder(client_, TestName()); + + Array4D input_array(1, 1, 3, 3, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto input = builder.ConstantR4FromArray4D(input_array); + + const Array4D filter_array(1, 1, 1, 1, {10}); + auto filter = builder.ConstantR4FromArray4D(filter_array); + + builder.Conv(input, filter, {2, 2}, Padding::kValid); + + Array4D expected(1, 1, 2, 2, {10, 30, 70, 90}); + ComputeAndCompareR4(&builder, expected, {}, error_spec_); +} + +TEST_F(ConvolutionVariantsTest, Filter3x1in1x1Padded) { + ComputationBuilder builder(client_, TestName()); + + Array4D input_array(1, 1, 1, 1, {1}); + auto input = builder.ConstantR4FromArray4D(input_array); + + const Array4D filter_array(1, 1, 1, 3, {10, 20, 30}); + auto filter = builder.ConstantR4FromArray4D(filter_array); + + builder.Conv(input, filter, {1, 1}, Padding::kSame); + + Array4D expected(1, 1, 1, 1, {20}); + ComputeAndCompareR4(&builder, expected, {}, error_spec_); +} + +TEST_F(ConvolutionVariantsTest, Filter5x1in3x1Padded) { + ComputationBuilder builder(client_, TestName()); + + Array4D input_array(1, 1, 1, 3, {1, 2, 3}); + auto input = builder.ConstantR4FromArray4D(input_array); + + const Array4D filter_array(1, 1, 1, 5, {10000, 1000, 100, 10, 1}); + auto filter = builder.ConstantR4FromArray4D(filter_array); + + builder.Conv(input, filter, {1, 1}, Padding::kSame); + + Array4D expected(1, 1, 1, 3, {123, 1230, 12300}); + ComputeAndCompareR4(&builder, expected, {}, error_spec_); +} + +TEST_F(ConvolutionVariantsTest, Filter3x3in2x2Padded) { + ComputationBuilder builder(client_, TestName()); + + Array4D input_array(1, 1, 2, 2, {1, 2, 3, 4}); + auto input = builder.ConstantR4FromArray4D(input_array); + + const Array4D filter_array(1, 1, 3, 3, {10000, 0, 1000, // row 0 + 0, 100, 0, // row 1 + 10, 0, 1}); // row 2 + auto filter = builder.ConstantR4FromArray4D(filter_array); + + builder.Conv(input, filter, {1, 1}, Padding::kSame); + + Array4D expected(1, 1, 2, 2, {104, 230, 2300, 10400}); + ComputeAndCompareR4(&builder, expected, {}, error_spec_); +} + +TEST_F(ConvolutionVariantsTest, Filter1x1in2x1WithPaddingAndDepth) { + ComputationBuilder builder(client_, TestName()); + + Array4D input_array(1, 2, 
1, 2, {1, 2, 3, 4}); + auto input = builder.ConstantR4FromArray4D(input_array); + + const Array4D filter_array(1, 2, 1, 1, {10, 1}); + auto filter = builder.ConstantR4FromArray4D(filter_array); + + builder.Conv(input, filter, {1, 1}, Padding::kSame); + + Array4D expected(1, 1, 1, 2, {13, 24}); + ComputeAndCompareR4(&builder, expected, {}, error_spec_); +} + +TEST_F(ConvolutionVariantsTest, Filter2x2Stride1x1Input3x3) { + ComputationBuilder builder(client_, TestName()); + + Array4D input_array(1, 1, 3, 3, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto input = builder.ConstantR4FromArray4D(input_array); + + const Array4D filter_array(1, 1, 2, 2, {7, 13, 17, 23}); + auto filter = builder.ConstantR4FromArray4D(filter_array); + + builder.Conv(input, filter, {1, 1}, Padding::kValid); + + Array4D expected(1, 1, 2, 2, {216, 276, 396, 456}); + ComputeAndCompareR4(&builder, expected, {}, error_spec_); +} + +TEST_F(ConvolutionVariantsTest, Filter1x2Stride1x1Input1x3) { + ComputationBuilder builder(client_, TestName()); + + Array4D input_array(1, 1, 1, 3, {1, 2, 3}); + auto input = builder.ConstantR4FromArray4D(input_array); + + const Array4D filter_array(1, 1, 1, 2, {7, 13}); + auto filter = builder.ConstantR4FromArray4D(filter_array); + + builder.Conv(input, filter, {1, 1}, Padding::kValid); + + Array4D expected(1, 1, 1, 2, {33, 53}); + ComputeAndCompareR4(&builder, expected, {}, error_spec_); +} + +TEST_F(ConvolutionVariantsTest, Filter2x1x8x8Input1x1x8x8) { + ComputationBuilder builder(client_, TestName()); + + std::vector input_data(64); + std::iota(input_data.begin(), input_data.end(), 0.0); + Array4D input_array(1, 1, 8, 8, input_data); + auto input = builder.ConstantR4FromArray4D(input_array); + + std::vector filter_data(128); + std::fill(filter_data.begin(), filter_data.begin() + 64, 1.0); + std::fill(filter_data.begin() + 64, filter_data.begin() + 128, 2.0); + const Array4D filter_array(2, 1, 8, 8, filter_data); + auto filter = builder.ConstantR4FromArray4D(filter_array); + + builder.Conv(input, filter, {1, 1}, Padding::kValid); + + Array4D expected(1, 2, 1, 1, {2016, 4032}); + ComputeAndCompareR4(&builder, expected, {}, error_spec_); +} + +TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input16x1x1x1) { + ComputationBuilder builder(client_, TestName()); + + std::vector input_data(16 * 1 * 1 * 1); + std::iota(input_data.begin(), input_data.end(), 1.0); + Array4D input_array(16, 1, 1, 1, input_data); + auto input = builder.ConstantR4FromArray4D(input_array); + + std::vector filter_data(1 * 1 * 1 * 1); + std::iota(filter_data.begin(), filter_data.end(), 1.0); + const Array4D filter_array(1, 1, 1, 1, filter_data); + auto filter = builder.ConstantR4FromArray4D(filter_array); + + builder.Conv(input, filter, {1, 1}, Padding::kValid); + + std::vector expected_data = {1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16}; + Array4D expected(16, 1, 1, 1, expected_data); + ComputeAndCompareR4(&builder, expected, {}, error_spec_); +} + +TEST_F(ConvolutionVariantsTest, Filter1x1x2x2Input16x1x2x2) { + ComputationBuilder builder(client_, TestName()); + + constexpr int bs = 16; + constexpr int kx = 2; + constexpr int ky = 2; + Array4D input_array(bs, 1, ky, kx); + for (int i0 = 0; i0 < bs; ++i0) { + for (int i2 = 0; i2 < ky; ++i2) { + for (int i3 = 0; i3 < kx; ++i3) { + input_array(i0, 0, i2, i3) = i0 + 1; + } + } + } + auto input = builder.ConstantR4FromArray4D(input_array); + + std::vector filter_data(1 * 1 * ky * kx); + std::iota(filter_data.begin(), filter_data.end(), 1.0); + const Array4D filter_array(1, 1, ky, 
kx, filter_data); + auto filter = builder.ConstantR4FromArray4D(filter_array); + + builder.Conv(input, filter, {1, 1}, Padding::kValid); + + std::vector expected_data(bs); + for (int i = 0; i < bs; ++i) { + expected_data[i] = 10 * (i + 1); + } + Array4D expected(bs, 1, 1, 1, expected_data); + ComputeAndCompareR4(&builder, expected, {}, error_spec_); +} + +TEST_F(ConvolutionVariantsTest, Filter1x1x2x2Input3x1x2x2) { + ComputationBuilder builder(client_, TestName()); + + constexpr int kx = 2; + constexpr int ky = 2; + constexpr int bs = 3; + Array4D input_array(bs, 1, ky, kx); + for (int i0 = 0; i0 < bs; ++i0) { + for (int i2 = 0; i2 < ky; ++i2) { + for (int i3 = 0; i3 < kx; ++i3) { + input_array(i0, 0, i2, i3) = i0 + i2 + i3 + 1; + } + } + } + auto input = builder.ConstantR4FromArray4D(input_array); + + std::vector filter_data(1 * 1 * ky * kx); + std::iota(filter_data.begin(), filter_data.end(), 1.0); + const Array4D filter_array(1, 1, ky, kx, filter_data); + auto filter = builder.ConstantR4FromArray4D(filter_array); + + builder.Conv(input, filter, {1, 1}, Padding::kValid); + + std::vector expected_data = { + 23, 33, 43, + }; + Array4D expected(bs, 1, 1, 1, expected_data); + ComputeAndCompareR4(&builder, expected, {}, error_spec_); +} + +TEST_F(ConvolutionVariantsTest, Filter1x1x8x8Input16x1x8x8) { + ComputationBuilder builder(client_, TestName()); + + Array4D input_array(16, 1, 8, 8); + for (int i0 = 0; i0 < 16; ++i0) { + for (int i2 = 0; i2 < 8; ++i2) { + for (int i3 = 0; i3 < 8; ++i3) { + input_array(i0, 0, i2, i3) = i0 + i2 + i3 + 1; + } + } + } + auto input = builder.ConstantR4FromArray4D(input_array); + + std::vector filter_data(1 * 1 * 8 * 8); + std::iota(filter_data.begin(), filter_data.end(), 1.0); + const Array4D filter_array(1, 1, 8, 8, filter_data); + auto filter = builder.ConstantR4FromArray4D(filter_array); + + builder.Conv(input, filter, {1, 1}, Padding::kValid); + + std::vector expected_data = { + 19664, 21744, 23824, 25904, 27984, 30064, 32144, 34224, + 36304, 38384, 40464, 42544, 44624, 46704, 48784, 50864, + }; + Array4D expected(16, 1, 1, 1, expected_data); + ComputeAndCompareR4(&builder, expected, {}, error_spec_); +} + +TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input1x2x8x8) { + ComputationBuilder builder(client_, TestName()); + + std::vector input_data(2 * 8 * 8); + std::iota(input_data.begin(), input_data.end(), 0.0); + Array4D input_array(1, 2, 8, 8, input_data); + auto input = builder.ConstantR4FromArray4D(input_array); + + std::vector filter_data(2 * 2 * 8 * 8); + std::fill(filter_data.begin(), filter_data.begin() + filter_data.size() / 4, + 1.0); + std::fill(filter_data.begin() + filter_data.size() / 4, + filter_data.begin() + filter_data.size() / 2, 2.0); + std::fill(filter_data.begin() + filter_data.size() / 2, + filter_data.begin() + 3 * filter_data.size() / 4, 3.0); + std::fill(filter_data.begin() + 3 * filter_data.size() / 4, filter_data.end(), + 4.0); + const Array4D filter_array(2, 2, 8, 8, filter_data); + auto filter = builder.ConstantR4FromArray4D(filter_array); + + builder.Conv(input, filter, {1, 1}, Padding::kValid); + + Array4D expected(1, 2, 1, 1, {14240, 30496}); + ComputeAndCompareR4(&builder, expected, {}, error_spec_); +} + +TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input2x2x8x8) { + ComputationBuilder builder(client_, TestName()); + + std::vector input_data(2 * 2 * 8 * 8); + std::iota(input_data.begin(), input_data.end(), 0.0); + Array4D input_array(2, 2, 8, 8, input_data); + auto input = builder.ConstantR4FromArray4D(input_array); + + 
std::vector filter_data(2 * 2 * 8 * 8); + std::fill(filter_data.begin(), filter_data.begin() + filter_data.size() / 4, + 1.0); + std::fill(filter_data.begin() + filter_data.size() / 4, + filter_data.begin() + filter_data.size() / 2, 2.0); + std::fill(filter_data.begin() + filter_data.size() / 2, + filter_data.begin() + 3 * filter_data.size() / 4, 3.0); + std::fill(filter_data.begin() + 3 * filter_data.size() / 4, filter_data.end(), + 4.0); + const Array4D filter_array(2, 2, 8, 8, filter_data); + auto filter = builder.ConstantR4FromArray4D(filter_array); + + builder.Conv(input, filter, {1, 1}, Padding::kValid); + + Array4D expected(2, 2, 1, 1, {14240, 30496, 38816, 87840}); + ComputeAndCompareR4(&builder, expected, {}, error_spec_); +} + +TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input32x2x8x8) { + ComputationBuilder builder(client_, TestName()); + + std::vector input_data(32 * 2 * 8 * 8); + std::iota(input_data.begin(), input_data.end(), 0.0); + Array4D input_array(32, 2, 8, 8, input_data); + auto input = builder.ConstantR4FromArray4D(input_array); + + std::vector filter_data(2 * 2 * 8 * 8); + std::fill(filter_data.begin(), filter_data.begin() + filter_data.size() / 4, + 1.0); + std::fill(filter_data.begin() + filter_data.size() / 4, + filter_data.begin() + filter_data.size() / 2, 2.0); + std::fill(filter_data.begin() + filter_data.size() / 2, + filter_data.begin() + 3 * filter_data.size() / 4, 3.0); + std::fill(filter_data.begin() + 3 * filter_data.size() / 4, filter_data.end(), + 4.0); + const Array4D filter_array(2, 2, 8, 8, filter_data); + auto filter = builder.ConstantR4FromArray4D(filter_array); + + builder.Conv(input, filter, {1, 1}, Padding::kValid); + + std::vector expected_data = { + 14240, 30496, 38816, 87840, 63392, 145184, 87968, + 202528, 112544, 259872, 137120, 317216, 161696, 374560, + 186272, 431904, 210848, 489248, 235424, 546592, 260000, + 603936, 284576, 661280, 309152, 718624, 333728, 775968, + 358304, 833312, 382880, 890656, 407456, 948000, 432032, + 1005344, 456608, 1062688, 481184, 1120032, 505760, 1177376, + 530336, 1.23472e+06, 554912, 1292064, 579488, 1349408, 604064, + 1406752, 628640, 1464096, 653216, 1.52144e+06, 677792, 1578784, + 702368, 1636128, 726944, 1693472, 751520, 1750816, 776096, + 1.80816e+06, + }; + Array4D expected(32, 2, 1, 1, expected_data); + // The output elements can be larger than 1e+5, making the absolute error + // large sometimes. So, we focus on relative errors for this test case. 
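// (Editorial note, not part of the original patch.) error_spec_ combines an
// absolute and a relative tolerance (1e-4 and 1e-2 on CPU above). For
// expected values around 1.8e+6 the absolute bound alone is smaller than a
// single float ulp at that magnitude, while the 1e-2 relative bound still
// allows on the order of 1.8e+4 of slack, which is what the preceding
// comment relies on.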
+ ComputeAndCompareR4(&builder, expected, {}, error_spec_); +} + +TEST_F(ConvolutionVariantsTest, Filter16x16x1x1Input16x16x1x1) { + ComputationBuilder builder(client_, TestName()); + + Array4D input_array(16, 16, 1, 1); + Array4D filter_array(16, 16, 1, 1); + for (int i0 = 0; i0 < 16; ++i0) { + for (int i1 = 0; i1 < 16; ++i1) { + input_array(i0, i1, 0, 0) = 1000 * i0 + i1; + filter_array(i0, i1, 0, 0) = 1; + } + } + + auto input = builder.ConstantR4FromArray4D(input_array); + auto filter = builder.ConstantR4FromArray4D(filter_array); + builder.Conv(input, filter, {1, 1}, Padding::kValid); + + Array4D expected(16, 16, 1, 1); + for (int i0 = 0; i0 < 16; ++i0) { + for (int i1 = 0; i1 < 16; ++i1) { + expected(i0, i1, 0, 0) = 16000 * i0 + 120; + } + } + + ComputeAndCompareR4(&builder, expected, {}, error_spec_); +} + +XLA_TEST_F(ConvolutionVariantsTest, FlatRhsDilation) { + ComputationBuilder builder(client_, TestName()); + + std::vector input_data(1 * 1 * 4 * 6); + std::iota(input_data.begin(), input_data.end(), 0.0); + Array4D input_array(1, 1, 4, 6, input_data); + + Array4D filter_array(1, 1, 2, 3, {1, 10, 100, 2, 20, 200}); + auto input = builder.ConstantR4FromArray4D(input_array); + auto filter = builder.ConstantR4FromArray4D(filter_array); + builder.ConvGeneralDilated( + /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, /*padding=*/{}, + /*lhs_dilation=*/{}, /*rhs_dilation=*/{2, 2}, + ComputationBuilder::CreateDefaultConvDimensionNumbers()); + + Array4D expected(1, 1, 2, 2, {3924, 4257, 5922, 6255}); + ComputeAndCompareR4(&builder, expected, {}, error_spec_); +} + +XLA_TEST_F(ConvolutionVariantsTest, FlatLhsDilation1D) { + ComputationBuilder builder(client_, TestName()); + + std::vector input_data(1 * 1 * 1 * 5); + std::iota(input_data.begin(), input_data.end(), 1.0); + Array4D input_array(1, 1, 1, 5, input_data); + + Array4D filter_array(1, 1, 1, 2, {10, 1}); + auto input = builder.ConstantR4FromArray4D(input_array); + auto filter = builder.ConstantR4FromArray4D(filter_array); + builder.ConvGeneralDilated( + /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, /*padding=*/{}, + /*lhs_dilation=*/{1, 2}, /*rhs_dilation=*/{}, + ComputationBuilder::CreateDefaultConvDimensionNumbers()); + + Array4D expected(1, 1, 1, 8, {10, 2, 20, 3, 30, 4, 40, 5}); + ComputeAndCompareR4(&builder, expected, {}, error_spec_); +} + +XLA_TEST_F(ConvolutionVariantsTest, FlatLhsDilation) { + ComputationBuilder builder(client_, TestName()); + + std::vector input_data(1 * 1 * 3 * 4); + std::iota(input_data.begin(), input_data.end(), 1.0); + Array4D input_array(1, 1, 3, 4, input_data); + + Array4D filter_array(1, 1, 4, 3, {100, 10, 1, // + 200, 20, 2, // + 300, 30, 3, // + 400, 40, 4}); + auto input = builder.ConstantR4FromArray4D(input_array); + auto filter = builder.ConstantR4FromArray4D(filter_array); + builder.ConvGeneralDilated( + /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{2, 1}, + /*padding=*/{{1, 0}, {0, 0}}, /*lhs_dilation=*/{3, 2}, + /*rhs_dilation=*/{}, + ComputationBuilder::CreateDefaultConvDimensionNumbers()); + + Array4D expected(1, 1, 3, 5, {204, 40, 406, 60, 608, // + 1518, 180, 1821, 210, 2124, // + 4146, 460, 4651, 510, 5156}); + ComputeAndCompareR4(&builder, expected, {}, error_spec_); +} + +XLA_TEST_F(ConvolutionVariantsTest, NegativePaddingOnBothEnds) { + ComputationBuilder builder(client_, TestName()); + + std::vector input_data(1 * 1 * 1 * 5); + std::iota(input_data.begin(), input_data.end(), 1.0); + Array4D input_array(1, 1, 1, 5, input_data); + + Array4D filter_array(1, 1, 1, 2, 
{10, 1}); + auto input = builder.ConstantR4FromArray4D(input_array); + auto filter = builder.ConstantR4FromArray4D(filter_array); + builder.ConvGeneral( + /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, + /*padding=*/{{0, 0}, {-1, -1}}, + ComputationBuilder::CreateDefaultConvDimensionNumbers()); + + Array4D expected(1, 1, 1, 2, {23, 34}); + ComputeAndCompareR4(&builder, expected, {}, error_spec_); +} + +XLA_TEST_F(ConvolutionVariantsTest, NegativePaddingLowAndPositivePaddingHigh) { + ComputationBuilder builder(client_, TestName()); + + std::vector input_data(1 * 1 * 1 * 5); + std::iota(input_data.begin(), input_data.end(), 1.0); + Array4D input_array(1, 1, 1, 5, input_data); + + Array4D filter_array(1, 1, 1, 2, {10, 1}); + auto input = builder.ConstantR4FromArray4D(input_array); + auto filter = builder.ConstantR4FromArray4D(filter_array); + builder.ConvGeneral( + /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, + /*padding=*/{{0, 0}, {-1, 2}}, + ComputationBuilder::CreateDefaultConvDimensionNumbers()); + + Array4D expected(1, 1, 1, 5, {23, 34, 45, 50, 0}); + ComputeAndCompareR4(&builder, expected, {}, error_spec_); +} + +XLA_TEST_F(ConvolutionVariantsTest, PositivePaddingLowAndNegativePaddingHigh) { + ComputationBuilder builder(client_, TestName()); + + std::vector input_data(1 * 1 * 1 * 5); + std::iota(input_data.begin(), input_data.end(), 1.0); + Array4D input_array(1, 1, 1, 5, input_data); + + Array4D filter_array(1, 1, 1, 2, {10, 1}); + auto input = builder.ConstantR4FromArray4D(input_array); + auto filter = builder.ConstantR4FromArray4D(filter_array); + builder.ConvGeneral( + /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, + /*padding=*/{{0, 0}, {2, -1}}, + ComputationBuilder::CreateDefaultConvDimensionNumbers()); + + Array4D expected(1, 1, 1, 5, {0, 1, 12, 23, 34}); + ComputeAndCompareR4(&builder, expected, {}, error_spec_); +} + +XLA_TEST_F(ConvolutionVariantsTest, PositivePaddingAndDilation) { + ComputationBuilder builder(client_, TestName()); + + std::vector input_data(1 * 1 * 1 * 5); + std::iota(input_data.begin(), input_data.end(), 1.0); + Array4D input_array(1, 1, 1, 5, input_data); + + Array4D filter_array(1, 1, 1, 2, {10, 1}); + auto input = builder.ConstantR4FromArray4D(input_array); + auto filter = builder.ConstantR4FromArray4D(filter_array); + builder.ConvGeneralDilated( + /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, + /*padding=*/{{0, 0}, {3, 2}}, + /*lhs_dilation=*/{1, 2}, /*rhs_dilation=*/{1, 2}, + ComputationBuilder::CreateDefaultConvDimensionNumbers()); + + // input: + // [1, 2, 3, 4, 5] --dilate-> [1, 0, 2, 0, 3, 0, 4, 0, 5] + // ---pad---> [0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 0] + // filter: + // [10, 1] --dilate-> [10, 0, 1] + Array4D expected(1, 1, 1, 12, + {0, 1, 0, 12, 0, 23, 0, 34, 0, 45, 0, 50}); + ComputeAndCompareR4(&builder, expected, {}, error_spec_); +} +XLA_TEST_F(ConvolutionVariantsTest, NegativePaddingAndDilation) { + ComputationBuilder builder(client_, TestName()); + + std::vector input_data(1 * 1 * 1 * 5); + std::iota(input_data.begin(), input_data.end(), 1.0); + Array4D input_array(1, 1, 1, 5, input_data); + + Array4D filter_array(1, 1, 1, 2, {10, 1}); + auto input = builder.ConstantR4FromArray4D(input_array); + auto filter = builder.ConstantR4FromArray4D(filter_array); + builder.ConvGeneralDilated( + /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, + /*padding=*/{{0, 0}, {-3, -2}}, + /*lhs_dilation=*/{1, 2}, /*rhs_dilation=*/{1, 2}, + ComputationBuilder::CreateDefaultConvDimensionNumbers()); + + // input: + // [1, 
2, 3, 4, 5] --dilate-> [1, 0, 2, 0, 3, 0, 4, 0, 5] + // ---pad---> [0, 3, 0, 4] + // filter: + // [10, 1] --dilate-> [10, 0, 1] + Array4D expected(1, 1, 1, 2, {0, 34}); + ComputeAndCompareR4(&builder, expected, {}, error_spec_); +} + +TEST_F(ConvolutionVariantsTest, RandomData_Input1x1x2x3_Filter2x1x1x2) { + constexpr int bs = 1; + constexpr int iz = 1; + constexpr int oz = 2; + constexpr int iy = 2; + constexpr int ix = 3; + constexpr int ky = 1; + constexpr int kx = 2; + std::mt19937 rng; + std::uniform_real_distribution distribution; + std::vector input_data(bs * iz * iy * ix); + for (float& f : input_data) { + f = distribution(rng); + } + std::vector kernel_data(oz * iz * ky * kx); + for (float& f : kernel_data) { + f = distribution(rng); + } + + Array4D input_array(bs, iz, iy, ix, input_data); + Array4D filter_array(oz, iz, ky, kx, kernel_data); + + ComputationBuilder builder(client_, TestName()); + auto input = builder.ConstantR4FromArray4D(input_array); + auto filter = builder.ConstantR4FromArray4D(filter_array); + builder.Conv(input, filter, {1, 1}, Padding::kValid); + + std::unique_ptr> expected = ReferenceUtil::ConvArray4D( + input_array, filter_array, {1, 1}, Padding::kValid); + + ComputeAndCompareR4(&builder, *expected, {}, error_spec_); +} + +TEST_F(ConvolutionVariantsTest, RandomData_Input1x16x1x1_Filter1x16x1x1) { + constexpr int bs = 1; + constexpr int iz = 16; + constexpr int oz = 1; + constexpr int iy = 1; + constexpr int ix = 1; + constexpr int ky = 1; + constexpr int kx = 1; + std::mt19937 rng; + std::uniform_real_distribution distribution; + std::vector input_data(bs * iz * iy * ix); + for (float& f : input_data) { + f = distribution(rng); + } + std::vector kernel_data(oz * iz * ky * kx); + for (float& f : kernel_data) { + f = distribution(rng); + } + + Array4D input_array(bs, iz, iy, ix, input_data); + Array4D filter_array(oz, iz, ky, kx, kernel_data); + + ComputationBuilder builder(client_, TestName()); + auto input = builder.ConstantR4FromArray4D(input_array); + auto filter = builder.ConstantR4FromArray4D(filter_array); + builder.Conv(input, filter, {1, 1}, Padding::kValid); + + std::unique_ptr> expected = ReferenceUtil::ConvArray4D( + input_array, filter_array, {1, 1}, Padding::kValid); + + ComputeAndCompareR4(&builder, *expected, {}, error_spec_); +} + +TEST_F(ConvolutionVariantsTest, RandomData_Input16x16x1x1_Filter1x16x1x1) { + constexpr int bs = 16; + constexpr int iz = 16; + constexpr int oz = 1; + constexpr int iy = 1; + constexpr int ix = 1; + constexpr int ky = 1; + constexpr int kx = 1; + std::mt19937 rng; + std::uniform_real_distribution distribution; + std::vector input_data(bs * iz * iy * ix); + for (float& f : input_data) { + f = distribution(rng); + } + std::vector kernel_data(oz * iz * ky * kx); + for (float& f : kernel_data) { + f = distribution(rng); + } + + Array4D input_array(bs, iz, iy, ix, input_data); + Array4D filter_array(oz, iz, ky, kx, kernel_data); + + ComputationBuilder builder(client_, TestName()); + auto input = builder.ConstantR4FromArray4D(input_array); + auto filter = builder.ConstantR4FromArray4D(filter_array); + builder.Conv(input, filter, {1, 1}, Padding::kValid); + + std::unique_ptr> expected = ReferenceUtil::ConvArray4D( + input_array, filter_array, {1, 1}, Padding::kValid); + + ComputeAndCompareR4(&builder, *expected, {}, error_spec_); +} + +TEST_F(ConvolutionVariantsTest, RandomData_Input16x16x1x1_Filter16x16x1x1) { + constexpr int bs = 16; + constexpr int iz = 16; + constexpr int oz = 16; + constexpr int iy = 1; + constexpr 
int ix = 1; + constexpr int ky = 1; + constexpr int kx = 1; + std::mt19937 rng; + std::uniform_real_distribution distribution; + std::vector input_data(bs * iz * iy * ix); + for (float& f : input_data) { + f = distribution(rng); + } + std::vector kernel_data(oz * iz * ky * kx); + for (float& f : kernel_data) { + f = distribution(rng); + } + + Array4D input_array(bs, iz, iy, ix, input_data); + Array4D filter_array(oz, iz, ky, kx, kernel_data); + + ComputationBuilder builder(client_, TestName()); + auto input = builder.ConstantR4FromArray4D(input_array); + auto filter = builder.ConstantR4FromArray4D(filter_array); + builder.Conv(input, filter, {1, 1}, Padding::kValid); + + std::unique_ptr> expected = ReferenceUtil::ConvArray4D( + input_array, filter_array, {1, 1}, Padding::kValid); + + ComputeAndCompareR4(&builder, *expected, {}, error_spec_); +} + +TEST_F(ConvolutionVariantsTest, RandomData_Input16x16x16x16_Filter16x16x16x16) { + constexpr int bs = 16; + constexpr int iz = 16; + constexpr int oz = 16; + constexpr int iy = 16; + constexpr int ix = 16; + constexpr int ky = 16; + constexpr int kx = 16; + std::mt19937 rng; + std::uniform_real_distribution distribution; + std::vector input_data(bs * iz * iy * ix); + for (float& f : input_data) { + f = distribution(rng); + } + std::vector kernel_data(oz * iz * ky * kx); + for (float& f : kernel_data) { + f = distribution(rng); + } + + Array4D input_array(bs, iz, iy, ix, input_data); + Array4D filter_array(oz, iz, ky, kx, kernel_data); + + ComputationBuilder builder(client_, TestName()); + auto input = builder.ConstantR4FromArray4D(input_array); + auto filter = builder.ConstantR4FromArray4D(filter_array); + builder.Conv(input, filter, {1, 1}, Padding::kValid); + + std::unique_ptr> expected = ReferenceUtil::ConvArray4D( + input_array, filter_array, {1, 1}, Padding::kValid); + + ComputeAndCompareR4(&builder, *expected, {}, error_spec_); +} + +TEST_F(ConvolutionVariantsTest, Filter1x2x1x1Input1x2x3x1GeneralPadding) { + ComputationBuilder builder(client_, TestName()); + + std::vector input_data(1 * 2 * 3 * 1); + std::iota(input_data.begin(), input_data.end(), 1.0); + Array4D input_array(1, 2, 3, 1, input_data); + auto input = builder.ConstantR4FromArray4D(input_array); + + std::vector filter_data(1 * 2 * 1 * 1); + std::iota(filter_data.begin(), filter_data.end(), 1.0); + Array4D filter_array(1, 2, 1, 1, filter_data); + auto filter = builder.ConstantR4FromArray4D(filter_array); + + ConvolutionDimensionNumbers dnums; + // NHWC input format. + dnums.set_batch_dimension(0); + dnums.add_spatial_dimensions(1); + dnums.add_spatial_dimensions(2); + dnums.set_feature_dimension(3); + + // Tensorflow filter shape: [ H, W, inC, outC ] + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + // Tests padding sizes that don't correspond either to SAME or VALID padding. 
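// (Editorial arithmetic, not part of the original patch.) With explicit
// {low, high} padding of {2, 1} rows and {2, 3} columns at stride 1, the
// output extent is (2 + 2 + 1) - 1 + 1 = 5 rows and (3 + 2 + 3) - 2 + 1 = 7
// columns, which is why the expected literal below has shape [1, 5, 7, 1].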
+ builder.ConvGeneral(input, filter, {1, 1}, {{2, 1}, {2, 3}}, dnums); + + std::vector expected_data = { + 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, // + 0, 2, 5, 8, 3, 0, 0, // + 0, 8, 14, 17, 6, 0, 0, // + 0, 0, 0, 0, 0, 0, 0 // + }; + Array4D expected(1, 5, 7, 1, expected_data); + ComputeAndCompareR4(&builder, expected, {}, error_spec_); +} + +TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1GeneralPadding) { + ComputationBuilder builder(client_, TestName()); + + std::vector input_data(1 * 2 * 3 * 1); + std::iota(input_data.begin(), input_data.end(), 1.0); + Array4D input_array(1, 2, 3, 1, input_data); + auto input = builder.ConstantR4FromArray4D(input_array); + + std::vector filter_data(1 * 1 * 1 * 1); + std::iota(filter_data.begin(), filter_data.end(), 2.0); + Array4D filter_array(1, 1, 1, 1, filter_data); + auto filter = builder.ConstantR4FromArray4D(filter_array); + + ConvolutionDimensionNumbers dnums; + // NHWC input format. + dnums.set_batch_dimension(0); + dnums.add_spatial_dimensions(1); + dnums.add_spatial_dimensions(2); + dnums.set_feature_dimension(3); + + // Tensorflow filter shape: [ H, W, inC, outC ] + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + // Tests padding sizes that don't correspond either to SAME or VALID padding. + builder.ConvGeneral(input, filter, {1, 1}, {{2, 1}, {2, 3}}, dnums); + + std::vector expected_data = { + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 2, 4, 6, 0, 0, 0, // + 0, 0, 8, 10, 12, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0 // + }; + Array4D expected(1, 5, 8, 1, expected_data); + ComputeAndCompareR4(&builder, expected, {}, error_spec_); +} + +TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1NoPadding) { + ComputationBuilder builder(client_, TestName()); + + std::vector input_data(1 * 2 * 3 * 1); + std::iota(input_data.begin(), input_data.end(), 1.0); + Array4D input_array(1, 2, 3, 1, input_data); + auto input = builder.ConstantR4FromArray4D(input_array); + + std::vector filter_data(1 * 1 * 1 * 1); + std::iota(filter_data.begin(), filter_data.end(), 2.0); + Array4D filter_array(1, 1, 1, 1, filter_data); + auto filter = builder.ConstantR4FromArray4D(filter_array); + + ConvolutionDimensionNumbers dnums; + // NHWC input format. + dnums.set_batch_dimension(0); + dnums.add_spatial_dimensions(1); + dnums.add_spatial_dimensions(2); + dnums.set_feature_dimension(3); + + // Tensorflow filter shape: [ H, W, inC, outC ] + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + // Tests zero padding sizes. This can use matmul for computation. 
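// (Editorial note, not part of the original patch.) With a 1x1 kernel and
// zero padding, each output element is a dot product over the input feature
// dimension: out[n, h, w, o] = sum_i in[n, h, w, i] * filter[0, 0, i, o],
// i.e. a plain (N*H*W, inC) x (inC, outC) matrix multiply. Here the single
// filter tap is 2.0, so the expected output is just the input doubled.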
+ builder.ConvGeneral(input, filter, {1, 1}, {{0, 0}, {0, 0}}, dnums); + + std::vector expected_data = { + 2, 4, 6, // + 8, 10, 12, + }; + Array4D expected(1, 2, 3, 1, expected_data); + ComputeAndCompareR4(&builder, expected, {}, error_spec_); +} + +TEST_F(ConvolutionVariantsTest, Filter1x1x2x3Input1x2x3x2NoPadding) { + ComputationBuilder builder(client_, TestName()); + + std::vector input_data(1 * 2 * 3 * 2); + std::iota(input_data.begin(), input_data.end(), 1.0); + Array4D input_array(1, 2, 3, 2, input_data); + auto input = builder.ConstantR4FromArray4D(input_array); + + std::vector filter_data(1 * 1 * 2 * 3); + std::iota(filter_data.begin(), filter_data.end(), 2.0); + Array4D filter_array(1, 1, 2, 3, filter_data); + auto filter = builder.ConstantR4FromArray4D(filter_array); + + ConvolutionDimensionNumbers dnums; + // NHWC input format. + dnums.set_batch_dimension(0); + dnums.add_spatial_dimensions(1); + dnums.add_spatial_dimensions(2); + dnums.set_feature_dimension(3); + + // Tensorflow filter shape: [ H, W, inC, outC ] + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + // Tests zero padding sizes. This can use matmul for computation. + builder.ConvGeneral(input, filter, {1, 1}, {{0, 0}, {0, 0}}, dnums); + + std::vector expected_data = { + 12, 15, 18, // + 26, 33, 40, // + 40, 51, 62, // + 54, 69, 84, // + 68, 87, 106, // + 82, 105, 128, // + }; + Array4D expected(1, 2, 3, 3, expected_data); + ComputeAndCompareR4(&builder, expected, {}, error_spec_); +} + +// Regression test for b/32034796. +// +// XLA:GPU fuses +// Conv([1,2,3], Reverse([5,6]), padding_low=1) +// into +// BackwardInputConv([1,2,3], [5,6], padding_low=0, padding_high=1) +TEST_F(ConvolutionVariantsTest, BackwardInputLowPaddingLessThanHighPadding) { + ComputationBuilder builder(client_, TestName()); + + auto gradients = builder.ConstantR4FromArray4D( + Array4D(1, 1, 1, 3, /*values=*/{1, 2, 3})); + auto weights = builder.ConstantR4FromArray4D( + Array4D(1, 1, 1, 2, /*values=*/{5, 6})); + auto mirrored_weights = builder.Rev(weights, {2, 3}); + builder.ConvWithGeneralPadding(gradients, mirrored_weights, + /*window_strides=*/{1, 1}, + /*padding=*/{{0, 0}, {1, 0}}); + ComputeAndCompareR4(&builder, {{{{5, 16, 27}}}}, {}, error_spec_); +} + +// XLA:GPU fuses +// Conv([1], Reverse([1,10,100]), padding_high=3, base_dilation=3) +// into +// BackwardInputConv([1], [1,10,100], stride=3, padding=(2,1)) +TEST_F(ConvolutionVariantsTest, BackwardInputLowPaddingGreaterThanHighPadding) { + ComputationBuilder builder(client_, TestName()); + + auto gradients = builder.ConstantR4FromArray4D( + Array4D(1, 1, 1, 1, /*values=*/{1})); + auto weights = builder.ConstantR4FromArray4D( + Array4D(1, 1, 1, 3, /*values=*/{1, 10, 100})); + auto mirrored_weights = builder.Rev(weights, {2, 3}); + builder.ConvGeneralDilated( + gradients, mirrored_weights, + /*window_strides=*/{1, 1}, + /*padding=*/{{0, 0}, {0, 3}}, + /*lhs_dilation=*/{1, 3}, /*rhs_dilation=*/{}, + ComputationBuilder::CreateDefaultConvDimensionNumbers()); + ComputeAndCompareR4(&builder, {{{{100, 0}}}}, {}, error_spec_); +} + +// XLA:GPU fuses +// Conv([1], Reverse([1,10,100]), padding=(1,1)) +// into +// BackwardInputConv([1], [1,10,100], padding=(1,1)) +TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding) { + ComputationBuilder builder(client_, TestName()); + + auto gradients = builder.ConstantR4FromArray4D( + Array4D(1, 1, 1, 1, /*values=*/{1})); + auto 
weights = builder.ConstantR4FromArray4D( + Array4D(1, 1, 1, 3, /*values=*/{1, 10, 100})); + auto mirrored_weights = builder.Rev(weights, {2, 3}); + builder.ConvWithGeneralPadding(gradients, mirrored_weights, + /*window_strides=*/{1, 1}, + /*padding=*/{{0, 0}, {1, 1}}); + ComputeAndCompareR4(&builder, {{{{10}}}}, {}, error_spec_); +} + +// HLO pattern +// Conv([1,2,3], Reverse([1,10], padding_high=2) +// could be fused to +// BackwardInputConv([1,2,3], [1,10], padding_low=1, padding_high=-1) +// +// However, XLA:GPU doesn't actually fuse it because PadInsertion doesn't +// support negative padding on backward convolution yet (b/32744257). +TEST_F(ConvolutionVariantsTest, BackwardInputWithNegativePaddingHigh) { + ComputationBuilder builder(client_, TestName()); + + auto gradients = builder.ConstantR4FromArray4D( + Array4D(1, 1, 1, 3, /*values=*/{1, 2, 3})); + auto weights = builder.ConstantR4FromArray4D( + Array4D(1, 1, 1, 2, /*values=*/{1, 10})); + auto mirrored_weights = builder.Rev(weights, {2, 3}); + builder.ConvWithGeneralPadding(gradients, mirrored_weights, + /*window_strides=*/{1, 1}, + /*padding=*/{{0, 0}, {0, 2}}); + + ComputeAndCompareR4(&builder, {{{{12, 23, 30, 0}}}}, {}, error_spec_); +} + +TEST_F(ConvolutionVariantsTest, BackwardFilterLowPaddingLessThanHighPadding) { + ComputationBuilder builder(client_, TestName()); + + // activations: 1,2,3,4 ---pad--> 0,1,2,3,4,0,0 + // gradients: 100,10,1 -dilate-> 100,0,10,0,1 + // weight gradients: 24,130,240 + // + // This pattern will be fused to backward convolution with padding=(1,2). + auto activations = builder.ConstantR4FromArray4D( + Array4D(1, 1, 1, 4, /*values=*/{1, 2, 3, 4})); + auto gradients = builder.ConstantR4FromArray4D( + Array4D(1, 1, 1, 3, /*values=*/{100, 10, 1})); + auto forward_conv = builder.ConvGeneralDilated( + activations, gradients, + /*window_strides=*/{1, 1}, + /*padding=*/{{0, 0}, {1, 2}}, + /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 2}, + ComputationBuilder::CreateDefaultConvDimensionNumbers()); + builder.Transpose(forward_conv, {0, 1, 2, 3}); + + ComputeAndCompareR4(&builder, {{{{24, 130, 240}}}}, {}, error_spec_); +} + +TEST_F(ConvolutionVariantsTest, + BackwardFilterLowPaddingGreaterThanHighPadding) { + ComputationBuilder builder(client_, TestName()); + + // activations: 1,2,3,4 ---pad--> 0,0,1,2,3,4 + // gradients: 100,10,1 -dilate-> 100,0,10,0,1 + // weight gradients: 13,24 + // + // This pattern will be fused to backward convolution with padding=(2,1). + // Note: both (2,1) and (2,0) are valid padding for the backward convolution + // because the stride is 2. + auto activations = builder.ConstantR4FromArray4D( + Array4D(1, 1, 1, 4, /*values=*/{1, 2, 3, 4})); + auto gradients = builder.ConstantR4FromArray4D( + Array4D(1, 1, 1, 3, /*values=*/{100, 10, 1})); + auto forward_conv = builder.ConvGeneralDilated( + activations, gradients, + /*window_strides=*/{1, 1}, + /*padding=*/{{0, 0}, {2, 0}}, + /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 2}, + ComputationBuilder::CreateDefaultConvDimensionNumbers()); + builder.Transpose(forward_conv, {0, 1, 2, 3}); + + ComputeAndCompareR4(&builder, {{{{13, 24}}}}, {}, error_spec_); +} + +TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding) { + ComputationBuilder builder(client_, TestName()); + + // activations: 1,2,3,4 ---pad--> 0,0,1,2,3,4,0 + // gradients: 100,10,1 -dilate-> 100,0,10,0,1 + // weight gradients: 13,24,130 + // + // This pattern will be fused to backward convolution with padding=(2,2). 
+ // Note: both (2,1) and (2,2) are valid padding for the backward convolution + // because the stride is 2. ConvolutionFolding prefers (2,2) because cuDNN + // supports even padding only -- using (2,1) would need extra effort of + // canonicalization. + auto activations = builder.ConstantR4FromArray4D( + Array4D(1, 1, 1, 4, /*values=*/{1, 2, 3, 4})); + auto gradients = builder.ConstantR4FromArray4D( + Array4D(1, 1, 1, 3, /*values=*/{100, 10, 1})); + auto forward_conv = builder.ConvGeneralDilated( + activations, gradients, + /*window_strides=*/{1, 1}, + /*padding=*/{{0, 0}, {2, 1}}, + /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 2}, + ComputationBuilder::CreateDefaultConvDimensionNumbers()); + builder.Transpose(forward_conv, {0, 1, 2, 3}); + + ComputeAndCompareR4(&builder, {{{{13, 24, 130}}}}, {}, error_spec_); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc new file mode 100644 index 0000000000..29e2950533 --- /dev/null +++ b/tensorflow/compiler/xla/tests/copy_test.cc @@ -0,0 +1,277 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class CopyOpTest : public HloTestBase { + protected: + void TestCopyOp(const Literal& literal) { + auto builder = HloComputation::Builder(TestName()); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(MakeUnique(literal))); + builder.AddInstruction(HloInstruction::CreateUnary( + constant->shape(), HloOpcode::kCopy, constant)); + auto computation = builder.Build(); + auto hlo_module = MakeUnique("test_module"); + hlo_module->AddEntryComputation(std::move(computation)); + + std::unique_ptr result = + ExecuteAndTransfer(std::move(hlo_module), {}); + LiteralTestUtil::ExpectEqual(literal, *result); + } + + void TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3); + void TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3, size_t n4, + tensorflow::gtl::ArraySlice permutation); +}; + +TEST_F(CopyOpTest, CopyR0Bool) { + TestCopyOp(*LiteralUtil::CreateR0(true)); +} + +TEST_F(CopyOpTest, CopyR1S0U32) { + TestCopyOp(*LiteralUtil::CreateR1({})); +} + +TEST_F(CopyOpTest, CopyR1S3U32) { + TestCopyOp(*LiteralUtil::CreateR1({1, 2, 3})); +} + +TEST_F(CopyOpTest, CopyR3F32_2x2x3) { + TestCopyOp( + *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, + {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); +} + +TEST_F(CopyOpTest, CopyR4S32_2x2x3x2) { + TestCopyOp(*LiteralUtil::CreateR4( + {{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}}, + {{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}})); +} + +TEST_F(CopyOpTest, CopyR4S32_0x2x3x2) { + TestCopyOp(*LiteralUtil::CreateR4FromArray4D(Array4D(0, 2, 3, 2))); +} + +TEST_F(CopyOpTest, CopyParameterScalar) { + auto builder = HloComputation::Builder(TestName()); + + // Copy literal to device to use as parameter. 
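+ // TransferToDevice is expected to stage the literal in device memory; the
+ // buffer it returns is then passed to ExecuteAndTransfer below as the
+ // argument bound to param0.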
+ auto literal = LiteralUtil::CreateR0(42.0); + Shape shape = literal->shape(); + auto constant_device_base = TransferToDevice(*literal); + + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param0")); + builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kCopy, param0)); + + auto computation = builder.Build(); + + auto hlo_module = MakeUnique("test_module"); + hlo_module->AddEntryComputation(std::move(computation)); + + std::unique_ptr result = + ExecuteAndTransfer(std::move(hlo_module), {constant_device_base}); + LiteralTestUtil::ExpectR0Near(42.0f, *result, error_spec_); +} + +TEST_F(CopyOpTest, CopyConstantR2Twice) { + auto builder = HloComputation::Builder(TestName()); + + auto literal = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(std::move(literal))); + + auto copy = builder.AddInstruction(HloInstruction::CreateUnary( + constant->shape(), HloOpcode::kCopy, constant)); + builder.AddInstruction( + HloInstruction::CreateUnary(copy->shape(), HloOpcode::kCopy, copy)); + + auto computation = builder.Build(); + + auto hlo_module = MakeUnique("test_module"); + hlo_module->AddEntryComputation(std::move(computation)); + std::unique_ptr result = + ExecuteAndTransfer(std::move(hlo_module), {}); + LiteralTestUtil::ExpectR2Near({{1.0, 2.0}, {3.0, 4.0}}, *result, + error_spec_); +} + +TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) { + HloComputation::Builder builder(TestName()); + + std::unique_ptr literal = + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + // Reverse the minor-to-major order of the literal. + Layout* literal_layout = literal->mutable_shape()->mutable_layout(); + ASSERT_EQ(2, literal_layout->minor_to_major_size()); + literal_layout->mutable_minor_to_major()->SwapElements(0, 1); + + HloInstruction* constant = builder.AddInstruction( + HloInstruction::CreateConstant(std::move(literal))); + + builder.AddInstruction(HloInstruction::CreateUnary( + constant->shape(), HloOpcode::kCopy, constant)); + + std::unique_ptr computation = builder.Build(); + + auto hlo_module = MakeUnique("test_module"); + hlo_module->AddEntryComputation(std::move(computation)); + std::unique_ptr result = + ExecuteAndTransfer(std::move(hlo_module), {}); + + // The result of the computation has the default layout, which is the inverse + // of the layout of the source literal. 
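+ // Concretely: the literal's data buffer still holds 1, 2, 3, 4, but under
+ // the swapped minor-to-major order it now denotes {{1, 3}, {2, 4}}, which
+ // is the value the copy is expected to reproduce.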
+ LiteralTestUtil::ExpectR2Near({{1.0, 3.0}, {2.0, 4.0}}, *result, + error_spec_); +} + +void CopyOpTest::TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3) { + Array3D a(n1, n2, n3); + for (size_t i = 0; i < n1; ++i) { + for (size_t j = 0; j < n2; ++j) { + for (size_t k = 0; k < n3; ++k) { + a(i, j, k) = i * n3 * n2 + j * n3 + k; + } + } + } + + HloComputation::Builder builder(TestName()); + + std::unique_ptr literal = LiteralUtil::CreateR3FromArray3D(a); + + HloInstruction* constant = builder.AddInstruction( + HloInstruction::CreateConstant(std::move(literal))); + + builder.AddInstruction(HloInstruction::CreateUnary( + constant->shape(), HloOpcode::kCopy, constant)); + + std::unique_ptr computation = builder.Build(); + + auto hlo_module = MakeUnique("test_module"); + auto config = MakeUnique(computation->ComputeProgramShape()); + *config->mutable_entry_computation_layout()->mutable_result_layout() = + ShapeLayout(ShapeUtil::MakeShapeWithLayout( + constant->shape().element_type(), + AsInt64Slice(constant->shape().dimensions()), {1, 2, 0})); + hlo_module->AddEntryComputation(std::move(computation)); + std::unique_ptr result = + ExecuteAndTransfer(std::move(hlo_module), std::move(config), {}); + + LiteralTestUtil::ExpectR3EqualArray3D(a, *result); +} + +void CopyOpTest::TestCopyConstantLayoutR4( + size_t n1, size_t n2, size_t n3, size_t n4, + tensorflow::gtl::ArraySlice permutation) { + Array4D a(n1, n2, n3, n4); + for (size_t i = 0; i < n1; ++i) { + for (size_t j = 0; j < n2; ++j) { + for (size_t k = 0; k < n3; ++k) { + for (size_t l = 0; l < n4; ++l) { + a(i, j, k, l) = i * n4 * n3 * n2 + j * n4 * n3 + k * n4 + l; + } + } + } + } + + HloComputation::Builder builder(TestName()); + + std::unique_ptr literal = LiteralUtil::CreateR4FromArray4D(a); + + HloInstruction* constant = builder.AddInstruction( + HloInstruction::CreateConstant(std::move(literal))); + + builder.AddInstruction(HloInstruction::CreateUnary( + constant->shape(), HloOpcode::kCopy, constant)); + + std::unique_ptr computation = builder.Build(); + + auto hlo_module = MakeUnique("test_module"); + auto config = MakeUnique(computation->ComputeProgramShape()); + *config->mutable_entry_computation_layout()->mutable_result_layout() = + ShapeLayout(ShapeUtil::MakeShapeWithLayout( + constant->shape().element_type(), + AsInt64Slice(constant->shape().dimensions()), ({ + std::vector p(permutation.rbegin(), permutation.rend()); + p; + }))); + hlo_module->AddEntryComputation(std::move(computation)); + std::unique_ptr result = + ExecuteAndTransfer(std::move(hlo_module), std::move(config), {}); + + LiteralTestUtil::ExpectR4EqualArray4D(a, *result); +} + +XLA_TEST_F(CopyOpTest, CopyConstantR3Layout021_SingleIncompleteTilePerLayer) { + TestCopyConstantLayout021(2, 2, 3); +} + +XLA_TEST_F(CopyOpTest, CopyConstantR3Layout021_SingleCompleteTilePerLayer) { + TestCopyConstantLayout021(2, 32, 32); +} + +XLA_TEST_F(CopyOpTest, CopyConstantR3Layout021_MultipleTilesPerLayer) { + TestCopyConstantLayout021(2, 70, 35); +} + +XLA_TEST_F(CopyOpTest, CopyConstantR4Layout0231_MultipleTilesPerLayer) { + TestCopyConstantLayoutR4(2, 70, 7, 5, {0, 2, 3, 1}); +} + +XLA_TEST_F(CopyOpTest, CopyConstantR4Layout0312_MultipleTilesPerLayer) { + TestCopyConstantLayoutR4(2, 14, 5, 35, {0, 3, 1, 2}); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = 
tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc new file mode 100644 index 0000000000..dc54c9defe --- /dev/null +++ b/tensorflow/compiler/xla/tests/custom_call_test.cc @@ -0,0 +1,148 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/test.h" + +extern "C" void __attribute__((visibility("default"))) +R0F32Add2(float* out, float** in) { + TF_ANNOTATE_MEMORY_IS_INITIALIZED(in, sizeof(float*)); + *out = **in + 2.0f; +} + +extern "C" void __attribute__((visibility("default"))) +R2F32ReduceSum(float* out, float** in) { + TF_ANNOTATE_MEMORY_IS_INITIALIZED(in, sizeof(float) * 4); + float* array = in[0]; + *out = array[0] + array[1] + array[2] + array[3]; +} + +extern "C" void __attribute__((visibility("default"))) +Add1ToValues(float* out, float** in) { + TF_ANNOTATE_MEMORY_IS_INITIALIZED(in, sizeof(float) * 4); + float* array = in[0]; + out[0] = array[0] + 1; + out[1] = array[1] + 1; + out[2] = array[2] + 1; + out[3] = array[3] + 1; +} + +namespace xla { +namespace { + +class CustomCallTest : public HloTestBase { + protected: + Shape r0f32_ = ShapeUtil::MakeShape(F32, {}); + Shape r2f32_ = ShapeUtil::MakeShape(F32, {2, 2}); +}; + +XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR0F32Add2)) { + auto hlo_module = MakeUnique("test_module"); + auto builder = HloComputation::Builder(TestName()); + + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + builder.AddInstruction( + HloInstruction::CreateCustomCall(r0f32_, {constant}, "R0F32Add2")); + + hlo_module->AddEntryComputation(builder.Build()); + + std::unique_ptr result = + ExecuteAndTransfer(std::move(hlo_module), {}); + LiteralTestUtil::ExpectR0Near(44.0f, *result, error_spec_); +} + +XLA_TEST_F(CustomCallTest, 
DISABLED_ON_GPU(CustomCallR2F32Reduce)) { + auto hlo_module = MakeUnique("test_module"); + auto builder = HloComputation::Builder(TestName()); + + Array2D array(2, 2); + array(0, 0) = 1.0f; + array(0, 1) = 2.0f; + array(1, 0) = 3.0f; + array(1, 1) = 4.0f; + + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR2FromArray2D(array))); + builder.AddInstruction( + HloInstruction::CreateCustomCall(r0f32_, {constant}, "R2F32ReduceSum")); + + hlo_module->AddEntryComputation(builder.Build()); + + std::unique_ptr result = + ExecuteAndTransfer(std::move(hlo_module), {}); + LiteralTestUtil::ExpectR0Near(10.0f, *result, error_spec_); +} + +XLA_TEST_F(CustomCallTest, + DISABLED_ON_GPU(CustomCall_UsedInOtherComputations)) { + auto hlo_module = MakeUnique("test_module"); + auto b = HloComputation::Builder(TestName()); + + auto input = b.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR2FromArray2D( + Array2D{{1.0f, 2.0f}, {3.0f, 4.0f}}))); + auto incremented = b.AddInstruction(HloInstruction::CreateCustomCall( + ShapeUtil::MakeShape(F32, {1, 2, 2}), {input}, "Add1ToValues")); + auto incremented_again = b.AddInstruction(HloInstruction::CreateCustomCall( + ShapeUtil::MakeShape(F32, {1, 2, 2}), {incremented}, "Add1ToValues")); + + // Concatenate the values along first dim. + b.AddInstruction( + HloInstruction::CreateConcatenate(ShapeUtil::MakeShape(F32, {2, 2, 2}), + {incremented, incremented_again}, 0)); + + hlo_module->AddEntryComputation(b.Build()); + + std::unique_ptr result = + ExecuteAndTransfer(std::move(hlo_module), {}); + LiteralTestUtil::ExpectR3EqualArray3D( + Array3D{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, *result); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/deallocation_test.cc b/tensorflow/compiler/xla/tests/deallocation_test.cc new file mode 100644 index 0000000000..528efd2942 --- /dev/null +++ b/tensorflow/compiler/xla/tests/deallocation_test.cc @@ -0,0 +1,155 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +class DeallocationTest : public ClientLibraryTestBase { + protected: + // Build and execute the given computation then verify the results can be + // transferred from the device successfully. + std::unique_ptr ExecuteAndCheckTransfer( + ComputationBuilder* builder, + tensorflow::gtl::ArraySlice arguments) { + Computation computation = builder->Build().ConsumeValueOrDie(); + auto global_data = + client_->Execute(computation, arguments).ConsumeValueOrDie(); + TF_CHECK_OK(client_->Transfer(*global_data).status()); + return global_data; + } +}; + +TEST_F(DeallocationTest, DeallocateScalar) { + ComputationBuilder builder(client_, TestName()); + builder.ConstantR0(42.0); + auto global_data = ExecuteAndCheckTransfer(&builder, {}); + + // A result can be transfered an arbitrary number of times. Add an extra + // transfer here so we're not just testing that a second call to Transfer + // fails. + ASSERT_IS_OK(client_->Transfer(*global_data).status()); + + ASSERT_IS_OK(client_->Unregister(*global_data)); + + auto transfer_status = client_->Transfer(*global_data); + ASSERT_FALSE(transfer_status.ok()); + ASSERT_MATCH(transfer_status.status().error_message(), + testing::HasSubstr("was previously deallocated")); +} + +TEST_F(DeallocationTest, DeallocateVector) { + ComputationBuilder builder(client_, TestName()); + builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); + auto global_data = ExecuteAndCheckTransfer(&builder, {}); + + ASSERT_IS_OK(client_->Unregister(*global_data)); + + auto transfer_status = client_->Transfer(*global_data); + ASSERT_FALSE(transfer_status.ok()); + ASSERT_MATCH(transfer_status.status().error_message(), + testing::HasSubstr("was previously deallocated")); +} + +TEST_F(DeallocationTest, DeallocateEmptyVector) { + ComputationBuilder builder(client_, TestName()); + builder.ConstantR1({}); + auto global_data = ExecuteAndCheckTransfer(&builder, {}); + + ASSERT_IS_OK(client_->Unregister(*global_data)); + + auto transfer_status = client_->Transfer(*global_data); + ASSERT_FALSE(transfer_status.ok()); + ASSERT_MATCH(transfer_status.status().error_message(), + testing::HasSubstr("was previously deallocated")); +} + +XLA_TEST_F(DeallocationTest, DeallocateTuple) { + ComputationBuilder builder(client_, TestName()); + builder.Tuple({builder.ConstantR0(42.0), + builder.ConstantR1({1.0, 2.0, 3.0})}); + auto global_data = ExecuteAndCheckTransfer(&builder, {}); + + ASSERT_IS_OK(client_->Unregister(*global_data)); + + auto transfer_status = client_->Transfer(*global_data); + ASSERT_FALSE(transfer_status.ok()); + ASSERT_MATCH(transfer_status.status().error_message(), + testing::HasSubstr("was previously deallocated")); +} + +XLA_TEST_F(DeallocationTest, DeallocateTupleWithRepeatedElements) { + ComputationBuilder builder(client_, TestName()); + auto element = 
builder.ConstantR0(42.0); + auto inner_tuple = builder.Tuple({builder.ConstantR0(42.0), element}); + builder.Tuple({element, inner_tuple, element}); + auto global_data = ExecuteAndCheckTransfer(&builder, {}); + + ASSERT_IS_OK(client_->Unregister(*global_data)); + + auto transfer_status = client_->Transfer(*global_data); + ASSERT_FALSE(transfer_status.ok()); + ASSERT_MATCH(transfer_status.status().error_message(), + testing::HasSubstr("was previously deallocated")); +} + +XLA_TEST_F(DeallocationTest, DeallocateNestedTuple) { + ComputationBuilder builder(client_, TestName()); + auto inner_tuple = + builder.Tuple({builder.ConstantR0(42.0), + builder.ConstantR1({1.0, 2.0, 3.0})}); + builder.Tuple({inner_tuple, builder.ConstantR1({0.123, 0.456})}); + auto global_data = ExecuteAndCheckTransfer(&builder, {}); + + ASSERT_IS_OK(client_->Unregister(*global_data)); + + auto transfer_status = client_->Transfer(*global_data); + ASSERT_FALSE(transfer_status.ok()); + ASSERT_MATCH(transfer_status.status().error_message(), + testing::HasSubstr("was previously deallocated")); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc new file mode 100644 index 0000000000..57a7c61b14 --- /dev/null +++ b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc @@ -0,0 +1,215 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +class DeconstructTupleTest : public ClientLibraryTestBase { + protected: + // Build and execute the given computation then verify the results can be + // transferred from the device successfully. + std::unique_ptr ExecuteAndCheckTransfer( + ComputationBuilder* builder, + tensorflow::gtl::ArraySlice arguments) { + Computation computation = builder->Build().ConsumeValueOrDie(); + auto global_data = + client_->Execute(computation, arguments).ConsumeValueOrDie(); + TF_CHECK_OK(client_->Transfer(*global_data).status()); + return global_data; + } +}; + +TEST_F(DeconstructTupleTest, DeconstructTuple) { + ComputationBuilder builder(client_, TestName()); + auto const1 = builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); + auto const2 = builder.ConstantR1({2.0, 4.0, 6.0, 8.0}); + builder.Tuple({const1, const2}); + auto global_data = ExecuteAndCheckTransfer(&builder, {}); + + auto result_status = client_->DeconstructTuple(*global_data); + EXPECT_TRUE(result_status.ok()); + + // Try copying the elements back and comparing it + auto handles = result_status.ConsumeValueOrDie(); + std::vector copy(4); + ASSERT_IS_OK(client_->TransferInProcess(*handles[0], ©[0])); + EXPECT_MATCH(copy, testing::VectorMatcher({1.0, 2.0, 3.0, 4.0})); + ASSERT_IS_OK(client_->TransferInProcess(*handles[1], ©[0])); + EXPECT_MATCH(copy, testing::VectorMatcher({2.0, 4.0, 6.0, 8.0})); +} + +TEST_F(DeconstructTupleTest, DeconstructTupleTwice) { + ComputationBuilder builder(client_, TestName()); + auto const1 = builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); + auto const2 = builder.ConstantR1({2.0, 4.0, 6.0, 8.0}); + builder.Tuple({const1, const2}); + auto global_data = ExecuteAndCheckTransfer(&builder, {}); + + auto result_status1 = client_->DeconstructTuple(*global_data); + EXPECT_TRUE(result_status1.ok()); + auto result_status2 = client_->DeconstructTuple(*global_data); + EXPECT_TRUE(result_status2.ok()); + + auto handles1 = result_status1.ConsumeValueOrDie(); + auto handles2 = result_status2.ConsumeValueOrDie(); + std::vector copy(4); + + ASSERT_IS_OK(client_->TransferInProcess(*handles1[0], ©[0])); + EXPECT_MATCH(copy, testing::VectorMatcher({1.0, 2.0, 3.0, 4.0})); + ASSERT_IS_OK(client_->TransferInProcess(*handles1[1], ©[0])); + EXPECT_MATCH(copy, testing::VectorMatcher({2.0, 4.0, 6.0, 8.0})); + handles1[0].reset(); + handles1[1].reset(); + + ASSERT_IS_OK(client_->TransferInProcess(*handles2[0], ©[0])); + EXPECT_MATCH(copy, testing::VectorMatcher({1.0, 2.0, 3.0, 4.0})); + ASSERT_IS_OK(client_->TransferInProcess(*handles2[1], ©[0])); + EXPECT_MATCH(copy, testing::VectorMatcher({2.0, 4.0, 6.0, 8.0})); +} + +XLA_TEST_F(DeconstructTupleTest, 
DeconstructTupleRepeatedElement) { + ComputationBuilder builder(client_, TestName()); + auto const1 = builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); + auto const2 = builder.ConstantR1({2.0, 4.0, 6.0, 8.0}); + builder.Tuple({const1, const2, const2, const1}); + auto global_data = ExecuteAndCheckTransfer(&builder, {}); + + auto result_status = client_->DeconstructTuple(*global_data); + EXPECT_TRUE(result_status.ok()); + + // Verify the returned GlobalDataHandle arrays have repeated elements like the + // tuple does. That is, in the returned vector of handles, handle[0] should be + // the same as handle[3] and handle[1] should be the same as handle[2]. + auto handles = result_status.ConsumeValueOrDie(); + + std::vector copy(4); + ASSERT_IS_OK(client_->TransferInProcess(*handles[0], ©[0])); + EXPECT_MATCH(copy, testing::VectorMatcher({1.0, 2.0, 3.0, 4.0})); + ASSERT_IS_OK(client_->TransferInProcess(*handles[1], ©[0])); + EXPECT_MATCH(copy, testing::VectorMatcher({2.0, 4.0, 6.0, 8.0})); + ASSERT_IS_OK(client_->TransferInProcess(*handles[2], ©[0])); + EXPECT_MATCH(copy, testing::VectorMatcher({2.0, 4.0, 6.0, 8.0})); + ASSERT_IS_OK(client_->TransferInProcess(*handles[3], ©[0])); + EXPECT_MATCH(copy, testing::VectorMatcher({1.0, 2.0, 3.0, 4.0})); +} + +TEST_F(DeconstructTupleTest, DeconstructTupleThenDeallocate) { + ComputationBuilder builder(client_, TestName()); + auto const1 = builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); + auto const2 = builder.ConstantR1({2.0, 4.0, 6.0, 8.0}); + builder.Tuple({const1, const2, const1}); + auto global_data = ExecuteAndCheckTransfer(&builder, {}); + + auto result_status = client_->DeconstructTuple(*global_data); + EXPECT_TRUE(result_status.ok()); + auto handles = result_status.ConsumeValueOrDie(); + + // Deallocate the tuple, then try copying the elements back. The elements + // should not have been deallocated because of reference counting. 
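+ // (The element buffers are assumed to stay registered for as long as any
+ // handle returned by DeconstructTuple refers to them, so resetting the
+ // outer tuple handle must not invalidate handles[0..2].)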
+ global_data.reset(); + + std::vector copy(4); + ASSERT_IS_OK(client_->TransferInProcess(*handles[0], ©[0])); + EXPECT_MATCH(copy, testing::VectorMatcher({1.0, 2.0, 3.0, 4.0})); + ASSERT_IS_OK(client_->TransferInProcess(*handles[1], ©[0])); + EXPECT_MATCH(copy, testing::VectorMatcher({2.0, 4.0, 6.0, 8.0})); + ASSERT_IS_OK(client_->TransferInProcess(*handles[2], ©[0])); + EXPECT_MATCH(copy, testing::VectorMatcher({1.0, 2.0, 3.0, 4.0})); + + /// Try deallocating one of the repeated elements, then copy + handles[0].reset(); + + ASSERT_IS_OK(client_->TransferInProcess(*handles[2], ©[0])); + EXPECT_MATCH(copy, testing::VectorMatcher({1.0, 2.0, 3.0, 4.0})); +} + +TEST_F(DeconstructTupleTest, DeconstructNonTuple) { + ComputationBuilder builder(client_, TestName()); + builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); + auto global_data = ExecuteAndCheckTransfer(&builder, {}); + + auto result_status = client_->DeconstructTuple(*global_data); + EXPECT_FALSE(result_status.ok()); + EXPECT_MATCH(result_status.status().error_message(), + testing::ContainsRegex("global data handle .* is not a tuple")); +} + +XLA_TEST_F(DeconstructTupleTest, DeconstructTupleFromParam) { + ComputationBuilder builder(client_, TestName()); + std::unique_ptr param0_literal = + LiteralUtil::CreateR1({3.14f, -100.25f}); + std::unique_ptr param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2}), "param0"); + builder.Tuple({p}); + auto global_data = ExecuteAndCheckTransfer(&builder, {param0_data.get()}); + + auto result_status = client_->DeconstructTuple(*global_data); + EXPECT_TRUE(result_status.ok()); + auto handles = result_status.ConsumeValueOrDie(); + EXPECT_NE(handles[0]->handle().handle(), param0_data->handle().handle()); +} + +XLA_TEST_F(DeconstructTupleTest, DeconstructNestedTuple) { + ComputationBuilder builder(client_, TestName()); + auto const1 = builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); + auto const2 = builder.ConstantR1({2.0, 4.0, 6.0, 8.0}); + builder.Tuple({builder.Tuple({const1, const2}), const1}); + auto global_data = ExecuteAndCheckTransfer(&builder, {}); + + auto result_status = client_->DeconstructTuple(*global_data); + EXPECT_FALSE(result_status.ok()); + EXPECT_MATCH( + result_status.status().error_message(), + testing::ContainsRegex("deconstructing nested tuples not yet supported")); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc new file mode 100644 index 0000000000..da2d43ca4f --- /dev/null +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -0,0 +1,387 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/array3d.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/layout_util_flags.h" +#include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/reference_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace xla { +namespace { + +// TODO(mfdyck): use GUnit typed tests when we can do all tests on all backends. +class DotOperationTest : public ClientLibraryTestBase { + public: + ErrorSpec error_spec_{0.0001, 1e-5}; + + protected: + template + void TestOneElementVectorDot(); + template + void TestVectorDot(); + template + void TestSquareMatrixDot(bool lhs_row_major = false, + bool rhs_row_major = false); + template + void TestNonsquareMatrixDot(bool lhs_row_major = false, + bool rhs_row_major = false); +}; + +XLA_TEST_F(DotOperationTest, ZeroElementVectorDotF32) { + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR1({}); + auto rhs = builder.ConstantR1({}); + auto result = builder.Dot(lhs, rhs); + + ComputeAndCompareR0(&builder, 0.0, {}, error_spec_); +} + +template +void DotOperationTest::TestOneElementVectorDot() { + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR1({2.0}); + auto rhs = builder.ConstantR1({3.0}); + auto result = builder.Dot(lhs, rhs); + + ComputeAndCompareR0(&builder, 6.0, {}, error_spec_); +} + +XLA_TEST_F(DotOperationTest, OneElementVectorDotF32) { + TestOneElementVectorDot(); +} + +XLA_TEST_F(DotOperationTest, OneElementVectorDotF64) { + TestOneElementVectorDot(); +} + +template +void DotOperationTest::TestVectorDot() { + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR1({1.0, 2.5, 42.0}); + auto rhs = builder.ConstantR1({11.0, -1.0, 0.5}); + auto result = builder.Dot(lhs, rhs); + + ComputeAndCompareR0(&builder, 29.5, {}, error_spec_); +} + +XLA_TEST_F(DotOperationTest, VectorDotF32) { TestVectorDot(); } + +XLA_TEST_F(DotOperationTest, VectorDotF64) { TestVectorDot(); } + +namespace { + +std::vector MinorToMajorForIsRowMajor(bool row_major) { + return {row_major ? 1 : 0, row_major ? 
0 : 1}; +} + +} // namespace + +XLA_TEST_F(DotOperationTest, Dot_0x2_2x0) { + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR2FromArray2D(Array2D(0, 2)); + auto rhs = builder.ConstantR2FromArray2D(Array2D(2, 0)); + auto result = builder.Dot(lhs, rhs); + + ComputeAndCompareR2(&builder, Array2D(0, 0), {}, error_spec_); +} + +XLA_TEST_F(DotOperationTest, Dot_0x2_2x3) { + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR2FromArray2D(Array2D(0, 2)); + auto rhs = builder.ConstantR2({{7.0, 8.0, 9.0}, {42.0, 77.0, 101.0}}); + auto result = builder.Dot(lhs, rhs); + + ComputeAndCompareR2(&builder, Array2D(0, 3), {}, error_spec_); +} + +XLA_TEST_F(DotOperationTest, Dot_3x2_2x0) { + ComputationBuilder builder(client_, TestName()); + auto lhs = + builder.ConstantR2({{7.0, 8.0}, {9.0, 42.0}, {77.0, 101.0}}); + auto rhs = builder.ConstantR2FromArray2D(Array2D(2, 0)); + auto result = builder.Dot(lhs, rhs); + + ComputeAndCompareR2(&builder, Array2D(3, 0), {}, error_spec_); +} + +XLA_TEST_F(DotOperationTest, Dot_2x0_0x2) { + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR2FromArray2D(Array2D(2, 0)); + auto rhs = builder.ConstantR2FromArray2D(Array2D(0, 2)); + auto result = builder.Dot(lhs, rhs); + + ComputeAndCompareR2(&builder, Array2D(2, 2, 0.0f), {}, + error_spec_); +} + +template +void DotOperationTest::TestSquareMatrixDot(bool lhs_row_major, + bool rhs_row_major) { + auto lhs_handle = + client_ + ->TransferToServer(*test_utils::CreateR2LiteralWithLayout( + {{1.0, 2.0}, {3.0, -4.0}}, + MinorToMajorForIsRowMajor(lhs_row_major))) + .ConsumeValueOrDie(); + auto rhs_handle = + client_ + ->TransferToServer(*test_utils::CreateR2LiteralWithLayout( + {{1.0, 6.0}, {7.0, -4.0}}, + MinorToMajorForIsRowMajor(rhs_row_major))) + .ConsumeValueOrDie(); + + ComputationBuilder builder(client_, TestName()); + auto prim_type = primitive_util::NativeToPrimitiveType(); + auto result = builder.Dot( + builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 2}), "lhs"), + builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {2, 2}), "rhs")); + + Array2D expected({{15.0, -2.0}, {-25.0, 34.0}}); + ComputeAndCompareR2( + &builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_); +} + +XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorFF) { + constexpr bool kLhsRowMajor = false; + constexpr bool kRhsRowMajor = false; + TestSquareMatrixDot(kLhsRowMajor, kRhsRowMajor); +} + +XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorFT) { + TestSquareMatrixDot(false, true); +} + +XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorTF) { + TestSquareMatrixDot(true, false); +} + +TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorTT) { + constexpr bool kLhsRowMajor = true; + constexpr bool kRhsRowMajor = true; + TestSquareMatrixDot(kLhsRowMajor, kRhsRowMajor); +} + +XLA_TEST_F(DotOperationTest, SquareMatrixDotF64) { + TestSquareMatrixDot(); +} + +template +void DotOperationTest::TestNonsquareMatrixDot(bool lhs_row_major, + bool rhs_row_major) { + auto lhs_handle = + client_ + ->TransferToServer(*test_utils::CreateR2LiteralWithLayout( + {{1.0, 2.0, 3.0}, {3.0, -4.0, -1.0}}, + MinorToMajorForIsRowMajor(lhs_row_major))) + .ConsumeValueOrDie(); + auto rhs_handle = + client_ + ->TransferToServer(*test_utils::CreateR2LiteralWithLayout( + {{1.0, 6.0}, {2.0, 3.0}, {7.0, -4.0}}, + MinorToMajorForIsRowMajor(rhs_row_major))) + .ConsumeValueOrDie(); + + ComputationBuilder builder(client_, TestName()); + auto prim_type = 
primitive_util::NativeToPrimitiveType(); + auto result = builder.Dot( + builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 3}), "lhs"), + builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {3, 2}), "rhs")); + + Array2D expected({{26.0, 0.0}, {-12.0, 10.0}}); + + ComputeAndCompareR2( + &builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_); +} + +XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorFF) { + constexpr bool kLhsRowMajor = false; + constexpr bool kRhsRowMajor = false; + TestNonsquareMatrixDot(kLhsRowMajor, kRhsRowMajor); +} + +XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorFT) { + constexpr bool kLhsRowMajor = false; + constexpr bool kRhsRowMajor = true; + TestNonsquareMatrixDot(kLhsRowMajor, kRhsRowMajor); +} + +XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorTF) { + constexpr bool kLhsRowMajor = true; + constexpr bool kRhsRowMajor = false; + TestNonsquareMatrixDot(kLhsRowMajor, kRhsRowMajor); +} + +TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorTT) { + constexpr bool kLhsRowMajor = true; + constexpr bool kRhsRowMajor = true; + TestNonsquareMatrixDot(kLhsRowMajor, kRhsRowMajor); +} + +XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF64) { + TestNonsquareMatrixDot(); +} + +TEST_F(DotOperationTest, ConcurrentMatMul) { + ComputationBuilder builder(client_, TestName()); + auto matrix1 = builder.ConstantR2({{1.0, 2.0}, {3.0, 4.0}}); + auto matrix2 = builder.ConstantR2({{5.0, 6.0}, {7.0, 8.0}}); + auto matrix12 = builder.Dot(matrix1, matrix2); + auto matrix21 = builder.Dot(matrix2, matrix1); + builder.Add(matrix12, matrix21); + + Array2D expected({{42.0, 56.0}, {74.0, 96.0}}); + ComputeAndCompareR2(&builder, expected, {}, error_spec_); +} + +// Regression test for b/32055648. The root of the graph is a kFusion of 4 +// bitcasts. Although bitcasts don't map to thunks, the root should still be +// sync-dependent on bitcasts' operands. +XLA_TEST_F(DotOperationTest, BatchMatMul) { + ComputationBuilder builder(client_, TestName()); + auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2, 2, 2}), "x"); + auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2, 2, 2}), "y"); + + auto x_flat = builder.Reshape(x, {0, 1, 2, 3}, {4, 2, 2}); + auto y_flat = builder.Reshape(y, {0, 1, 2, 3}, {4, 2, 2}); + + // Slice batches into individual matrices and multiply them. + std::vector out_slices; + for (int i = 0; i < 4; ++i) { + // Slice off individual matrices and reshape to 2D tensors. 
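+ // Slice takes start indices and exclusive limit indices, so {i, 0, 0} ..
+ // {i + 1, 2, 2} extracts the i-th 2x2 matrix of the flattened [4, 2, 2]
+ // operand.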
+ auto x_slice = builder.Slice(x_flat, {i, 0, 0}, {i + 1, 2, 2}); + x_slice = builder.Reshape(x_slice, {0, 1, 2}, {2, 2}); + auto y_slice = builder.Slice(y_flat, {i, 0, 0}, {i + 1, 2, 2}); + y_slice = builder.Reshape(y_slice, {0, 1, 2}, {2, 2}); + + auto out = builder.Dot(x_slice, y_slice); + out = builder.Reshape(out, {0, 1}, {1, 2, 2}); + out_slices.push_back(out); + } + auto out_flat = builder.ConcatInDim(out_slices, 0); + builder.Reshape(out_flat, {0, 1, 2}, {2, 2, 2, 2}); + + auto x_data = client_ + ->TransferToServer(*LiteralUtil::CreateR4( + {{{{1000, 100}, {10, 1}}, {{2000, 200}, {20, 2}}}, + {{{3000, 300}, {30, 3}}, {{4000, 400}, {40, 4}}}})) + .ConsumeValueOrDie(); + auto y_data = client_ + ->TransferToServer(*LiteralUtil::CreateR4( + {{{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}, + {{{11, 22}, {33, 44}}, {{55, 66}, {77, 88}}}})) + .ConsumeValueOrDie(); + + ComputeAndCompareR4( + &builder, + /*expected=*/{{{{1300, 2400}, {13, 24}}, {{11400, 13600}, {114, 136}}}, + {{{42900, 79200}, {429, 792}}, + {{250800, 299200}, {2508, 2992}}}}, + {x_data.get(), y_data.get()}, error_spec_); +} + +TEST_F(DotOperationTest, TransposeFolding) { + for (bool transpose_lhs : {false, true}) { + for (bool transpose_rhs : {false, true}) { + for (bool row_major : {false, true}) { + std::unique_ptr> lhs( + new Array2D({{1.0, 2.0, 3.0}, {3.0, -4.0, -1.0}})); + std::unique_ptr> rhs( + new Array2D({{1.0, 6.0}, {2.0, 3.0}, {7.0, -4.0}})); + + if (transpose_lhs) { + lhs = ReferenceUtil::TransposeArray2D(*lhs); + } + if (transpose_rhs) { + rhs = ReferenceUtil::TransposeArray2D(*rhs); + } + auto lhs_handle = + client_ + ->TransferToServer( + *LiteralUtil::CreateR2FromArray2DWithLayout( + *lhs, LayoutUtil::MakeLayout( + MinorToMajorForIsRowMajor(row_major)))) + .ConsumeValueOrDie(); + auto rhs_handle = + client_ + ->TransferToServer( + *LiteralUtil::CreateR2FromArray2DWithLayout( + *rhs, LayoutUtil::MakeLayout( + MinorToMajorForIsRowMajor(row_major)))) + .ConsumeValueOrDie(); + + ComputationBuilder builder(client_, TestName()); + auto prim_type = primitive_util::NativeToPrimitiveType(); + auto lhs_arg = builder.Parameter( + 0, ShapeUtil::MakeShape(prim_type, {lhs->height(), lhs->width()}), + "lhs"); + auto rhs_arg = builder.Parameter( + 1, ShapeUtil::MakeShape(prim_type, {rhs->height(), rhs->width()}), + "rhs"); + if (transpose_lhs) { + lhs_arg = builder.Transpose(lhs_arg, {1, 0}); + } + if (transpose_rhs) { + rhs_arg = builder.Transpose(rhs_arg, {1, 0}); + } + auto result = builder.Dot(lhs_arg, rhs_arg); + + Array2D expected({{26.0, 0.0}, {-12.0, 10.0}}); + VLOG(1) << "TestTransposeFolding " << transpose_lhs << " " + << transpose_rhs << " " << row_major; + ComputeAndCompareR2(&builder, expected, + {lhs_handle.get(), rhs_handle.get()}, + error_spec_); + } + } + } +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendLayoutUtilFlags(&flag_list); + xla::legacy_flags::AppendCpuRuntimeFlags(&flag_list); + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc 
new file mode 100644 index 0000000000..cecc4872df --- /dev/null +++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc @@ -0,0 +1,506 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/reference_util.h" +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/local_service.h" +#include "tensorflow/compiler/xla/service/platform_util.h" +#include "tensorflow/compiler/xla/service/shaped_buffer.h" +#include "tensorflow/compiler/xla/service/transfer_manager.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/platform/types.h" + +namespace se = ::perftools::gputools; + +namespace xla { +namespace { + +class DynamicSliceTest : public ClientLibraryTestBase { + protected: + template + void TestR1() { + // Slice at dimension start. + RunR1({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}, {0}, {5}, + {0.0, 1.0, 2.0, 3.0, 4.0}); + // Slice in the middle. + RunR1({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}, {2}, {3}, + {2.0, 3.0, 4.0}); + // Slice at dimension boundaries. + RunR1({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}, {5}, {3}, + {5.0, 6.0, 7.0}); + // Slice at dimension boundaries, but with sizes that cause indices to wrap. + RunR1({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}, {6}, {4}, + {6.0, 7.0, 0.0, 1.0}); + } + + template + void TestR2() { + // Slice at dimension start. + RunR2({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}}, + {0, 0}, {2, 2}, {{1.0f, 2.0f}, {4.0f, 5.0f}}); + // Slice in the middle. + RunR2({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}}, + {1, 1}, {2, 1}, {{5.0f}, {8.0f}}); + // Slice at dimension boundaries. + RunR2({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}}, + {1, 1}, {2, 1}, {{5.0f}, {8.0f}}); + // Slice at dimension boundaries, but with sizes that cause indices to wrap. + RunR2({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}}, + {1, 1}, {3, 3}, + {{5.0f, 6.0f, 4.0f}, {8.0f, 9.0f, 7.0f}, {2.0f, 3.0f, 1.0f}}); + } + + template + void TestR3() { + // R3 Shape: [2, 3, 2] + // clang-format off + + // Slice at dimension start. 
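+ // (Arguments to RunR3: input values, slice start indices, slice sizes, and
+ // the expected result. Starting at {0, 0, 0} with sizes {2, 1, 2} keeps
+ // both layers but only the first row of each.)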
+ RunR3( + {{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}}, + {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}}, + {0, 0, 0}, {2, 1, 2}, + {{{1.0f, 2.0f}}, {{7.0f, 8.0f}}}); + + // Slice in the middle. + RunR3( + {{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}}, + {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}}, + {0, 1, 1}, {2, 2, 1}, + {{{4.0f}, {6.0f}}, {{10.0f}, {12.0f}}}); + + // Slice at dimension boundaries, but with sizes that cause indices to wrap. + RunR3( + {{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}}, + {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}}, + {0, 2, 1}, {2, 2, 1}, + {{{6.0f}, {2.0f}}, {{12.0f}, {8.0f}}}); + + // clang-format on + } + + template + void RunR1(const std::vector& input_values, + const std::vector slice_starts, + const std::vector slice_sizes, + const std::vector& expected_values) { + ComputationBuilder builder(client_, TestName()); + // Initialize and transfer dynamic slice start indices parameter. + ComputationDataHandle starts; + std::unique_ptr start_data = CreateR1Parameter( + slice_starts, 0, "slice_starts", &builder, &starts); + // Build dynamic slice computation. + auto input = builder.ConstantR1(input_values); + builder.DynamicSlice(input, starts, slice_sizes); + // Run computation and compare against expected values. + ComputeAndCompareR1(&builder, expected_values, {start_data.get()}, + ErrorSpec(0.000001)); + } + + template + void RunR2(const Array2D& input_values, + const std::vector slice_starts, + const std::vector slice_sizes, + const Array2D& expected_values) { + ComputationBuilder builder(client_, TestName()); + // Initialize and transfer dynamic slice start indices parameter. + ComputationDataHandle starts; + std::unique_ptr start_data = CreateR1Parameter( + slice_starts, 0, "slice_starts", &builder, &starts); + // Build dynamic slice computation. + auto input = builder.ConstantR2FromArray2D(input_values); + builder.DynamicSlice(input, starts, slice_sizes); + // Run computation and compare against expected values. + ComputeAndCompareR2(&builder, expected_values, {start_data.get()}, + ErrorSpec(0.000001)); + } + + template + void RunR3(const Array3D& input_values, + const std::vector slice_starts, + const std::vector slice_sizes, + const Array3D& expected_values) { + ComputationBuilder builder(client_, TestName()); + // Initialize and transfer dynamic slice start indices parameter. + ComputationDataHandle starts; + std::unique_ptr start_data = CreateR1Parameter( + slice_starts, 0, "slice_starts", &builder, &starts); + // Build dynamic slice computation. + auto input = builder.ConstantR3FromArray3D(input_values); + builder.DynamicSlice(input, starts, slice_sizes); + // Run computation and compare against expected values. + ComputeAndCompareR3(&builder, expected_values, {start_data.get()}, + ErrorSpec(0.000001)); + } +}; + +XLA_TEST_F(DynamicSliceTest, Int32R1) { TestR1(); } + +XLA_TEST_F(DynamicSliceTest, Int64R1) { TestR1(); } + +XLA_TEST_F(DynamicSliceTest, UInt64R1) { TestR1(); } + +XLA_TEST_F(DynamicSliceTest, Int32R2) { TestR2(); } + +XLA_TEST_F(DynamicSliceTest, Int64R2) { TestR2(); } + +XLA_TEST_F(DynamicSliceTest, UInt64R2) { TestR2(); } + +XLA_TEST_F(DynamicSliceTest, Int32R3) { TestR3(); } + +XLA_TEST_F(DynamicSliceTest, Int64R3) { TestR3(); } + +XLA_TEST_F(DynamicSliceTest, UInt64R3) { TestR3(); } + +class DynamicUpdateSliceTest : public ClientLibraryTestBase { + protected: + template + void TestR1() { + // clang-format off + // Slice at dimension start. 
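+ // DynamicUpdateSlice overwrites a window of the operand with the update
+ // values starting at the given indices: writing {8, 9, 10} at index 0 into
+ // 0..7 yields {8, 9, 10, 3, 4, 5, 6, 7}, as the first case checks.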
+ RunR1({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}, + {8.0, 9.0, 10.0}, {0}, + {8.0, 9.0, 10.0, 3.0, 4.0, 5.0, 6.0, 7.0}); + // Slice in the middle. + RunR1({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}, + {8.0, 9.0, 10.0}, {2}, + {0.0, 1.0, 8.0, 9.0, 10.0, 5.0, 6.0, 7.0}); + // Slice at dimension boundaries. + RunR1({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}, + {8.0, 9.0, 10.0}, {5}, + {0.0, 1.0, 2.0, 3.0, 4.0, 8.0, 9.0, 10.0}); + // Slice at dimension boundaries, but with sizes that cause indices to wrap. + RunR1({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}, + {8.0, 9.0, 10.0}, {6}, + {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 8.0, 9.0}); + // clang-format on + } + + template + void TestR2() { + // clang-format off + // Slice at dimension start. + RunR2( + {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}}, + {{10.0f, 11.0f}}, {0, 0}, + {{10.0f, 11.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}}); + // Slice in the middle. + RunR2( + {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}}, + {{10.0f, 11.0f}}, {1, 1}, + {{1.0f, 2.0f, 3.0f}, {4.0f, 10.0f, 11.0f}, {7.0f, 8.0f, 9.0f}}); + // Slice at dimension boundaries. + RunR2( + {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}}, + {{10.0f, 11.0f}}, {2, 1}, + {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 10.0f, 11.0f}}); + // Slice at dimension boundaries, but with sizes that cause indices to wrap. + RunR2( + {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}}, + {{10.0f, 11.0f}}, {2, 2}, + {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 10.0f}}); + // clang-format on + } + + template + void TestR3() { + // R3 Shape: [2, 3, 2] + // clang-format off + // Slice at dimension start. + RunR3( + {{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}}, + {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}}, + {{{13.0f, 14.0f}, {15.0f, 16.0f}}, + {{17.0f, 18.0f}, {19.0f, 20.0f}}}, + {0, 0, 0}, + {{{13.0f, 14.0f}, {15.0f, 16.0f}, {5.0f, 6.0f}}, + {{17.0f, 18.0f}, {19.0f, 20.0f}, {11.0f, 12.0f}}}); + // Slice in the middle. + RunR3( + {{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}}, + {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}}, + {{{13.0f}, {15.0f}}}, + {1, 1, 1}, + {{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}}, + {{7.0f, 8.0f}, {9.0f, 13.0f}, {11.0f, 15.0f}}}); + // Slice at dimension boundaries, but with sizes that cause indices to wrap. + RunR3( + {{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}}, + {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}}, + {{{13.0f}, {15.0f}}}, + {1, 2, 1}, + {{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}}, + {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 13.0f}}}); + // clang-format on + } + + template + void RunR1(const std::vector& input_values, + const std::vector& update_values, + const std::vector slice_starts, + const std::vector& expected_values) { + ComputationBuilder builder(client_, TestName()); + // Initialize and transfer dynamic slice start indices parameter. + ComputationDataHandle starts; + std::unique_ptr start_data = CreateR1Parameter( + slice_starts, 0, "slice_starts", &builder, &starts); + // Build dynamic slice computation. + auto input = builder.ConstantR1(input_values); + auto update = builder.ConstantR1(update_values); + builder.DynamicUpdateSlice(input, update, starts); + // Run computation and compare against expected values. 
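+ // The start indices are supplied as a parameter (start_data) rather than a
+ // constant, presumably so the update offset stays dynamic instead of being
+ // constant-folded away.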
+ ComputeAndCompareR1(&builder, expected_values, {start_data.get()}, + ErrorSpec(0.000001)); + } + + template + void RunR2(const Array2D& input_values, + const Array2D& update_values, + const std::vector slice_starts, + const Array2D& expected_values) { + ComputationBuilder builder(client_, TestName()); + // Initialize and transfer dynamic slice start indices parameter. + ComputationDataHandle starts; + std::unique_ptr start_data = CreateR1Parameter( + slice_starts, 0, "slice_starts", &builder, &starts); + // Build dynamic slice computation. + auto input = builder.ConstantR2FromArray2D(input_values); + auto update = builder.ConstantR2FromArray2D(update_values); + builder.DynamicUpdateSlice(input, update, starts); + // Run computation and compare against expected values. + ComputeAndCompareR2(&builder, expected_values, {start_data.get()}, + ErrorSpec(0.000001)); + } + + template + void RunR3(const Array3D& input_values, + const Array3D& update_values, + const std::vector slice_starts, + const Array3D& expected_values) { + ComputationBuilder builder(client_, TestName()); + // Initialize and transfer dynamic slice start indices parameter. + ComputationDataHandle starts; + std::unique_ptr start_data = CreateR1Parameter( + slice_starts, 0, "slice_starts", &builder, &starts); + // Build dynamic slice computation. + auto input = builder.ConstantR3FromArray3D(input_values); + auto update = builder.ConstantR3FromArray3D(update_values); + builder.DynamicUpdateSlice(input, update, starts); + // Run computation and compare against expected values. + ComputeAndCompareR3(&builder, expected_values, {start_data.get()}, + ErrorSpec(0.000001)); + } + + void RunR3Contiguous(std::vector operand_shape, int32 index, + int32 size) { + const int32 kSeq = operand_shape[0]; + const int32 kBatch = operand_shape[1]; + const int32 kDim = operand_shape[2]; + Array3D input_values(kSeq, kBatch, kDim); + Array3D update_values(size, kBatch, kDim); + Array3D expected_values(kSeq, kBatch, kDim); + + input_values.FillIota(0); + float val = 1000; + update_values.FillIota(val); + + // TODO(b/34128753) Expected values may vary depending on backend when + // the update wraps. According to documentation, the results are technically + // implementation specific where the update is out of bounds, and hence + // we don't really know what to pass into ComputeAndCompareR3. + expected_values.FillIota(0); + for (int i = 0; i < size; i++) { + for (int j = 0; j < kBatch; j++) { + for (int k = 0; k < kDim; k++) { + expected_values((index + i) % kSeq, j, k) = val++; + } + } + } + if (VLOG_IS_ON(1)) { + DumpArray("input", input_values); + DumpArray("update", update_values); + DumpArray("expected", expected_values); + } + + // Build dynamic slice computation. + ComputationBuilder builder(client_, TestName()); + auto starts = builder.ConstantR1({index, 0, 0}); + auto input = builder.ConstantR3FromArray3D(input_values); + auto update = builder.ConstantR3FromArray3D(update_values); + builder.DynamicUpdateSlice(input, update, starts); + + // Run computation and compare against expected values. 
+ ComputeAndCompareR3(&builder, expected_values, {}, + ErrorSpec(0.000001)); + } + + template + void DumpArray(const string& name, const Array3D values) { + std::unique_ptr literal = + LiteralUtil::CreateR3FromArray3D(values); + LOG(INFO) << name << ":" << LiteralUtil::ToString(*literal); + } +}; + +XLA_TEST_F(DynamicUpdateSliceTest, Int32R1) { TestR1(); } + +XLA_TEST_F(DynamicUpdateSliceTest, Int64R1) { TestR1(); } + +XLA_TEST_F(DynamicUpdateSliceTest, UInt64R1) { TestR1(); } + +XLA_TEST_F(DynamicUpdateSliceTest, Int32R2) { TestR2(); } + +XLA_TEST_F(DynamicUpdateSliceTest, Int64R2) { TestR2(); } + +XLA_TEST_F(DynamicUpdateSliceTest, UInt64R2) { TestR2(); } + +XLA_TEST_F(DynamicUpdateSliceTest, Int32R3) { TestR3(); } + +XLA_TEST_F(DynamicUpdateSliceTest, Int64R3) { TestR3(); } + +XLA_TEST_F(DynamicUpdateSliceTest, UInt64R3) { TestR3(); } + +// Tests for simple R3 case where the update is contiguous (i.e. the minor +// two dimensions are not sliced). +XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousSingleElement) { + // Single element, no wrap. + std::vector operand_shape({4, 5, 2}); + RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/1); +} + +XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleElements) { + // Multiple element, no wrap. + std::vector operand_shape({4, 5, 2}); + RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/2); +} + +// TODO(b/34128753) CPU and GPU failed on 2016-01-06. Appears not to handle +// wrapping as expected. +XLA_TEST_F(DynamicUpdateSliceTest, + DISABLED_ON_CPU(DISABLED_ON_GPU(R3ContiguousMultipleWrapping))) { + // Multiple element, wrapping. + std::vector operand_shape({4, 5, 2}); + RunR3Contiguous(operand_shape, /*index=*/3, /*size=*/2); +} + +// TODO(b/34128753) CPU and GPU failed on 2016-01-06. Appears not to handle +// wrapping as expected. +XLA_TEST_F(DynamicUpdateSliceTest, + DISABLED_ON_CPU(DISABLED_ON_GPU(R3ContiguousTooLarge))) { + // Multiple element, update size larger than operand. + std::vector operand_shape({4, 5, 2}); + RunR3Contiguous(operand_shape, /*index=*/5, /*size=*/2); +} + +XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousUnaligned) { + std::vector operand_shape({3, 123, 247}); + RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/1); +} + +// TODO(b/34134076) Disabled on GPU 2016-01-06 due to out-of-memory error. 
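// Worked check of the expected-value construction in RunR3Contiguous above,
// using the same (index + i) % kSeq convention it encodes: for the
// R3ContiguousMultipleWrapping case (operand_shape {4, 5, 2}, index 3,
// size 2) the update lands in major-dimension rows (3 + 0) % 4 = 3 and
// (3 + 1) % 4 = 0. By contrast, the hand-written expectations in
// TestR1..TestR3 drop the part of the update that falls past the end (start
// {6} with a 3-element update into an 8-element operand only rewrites
// positions 6 and 7), which is the ambiguity TODO(b/34128753) flags. The
// asserts below are only an illustrative restatement of the wrap arithmetic,
// not part of the original tests:
static_assert((3 + 0) % 4 == 3, "first wrapped row of the contiguous update");
static_assert((3 + 1) % 4 == 0, "second wrapped row of the contiguous update");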
+XLA_TEST_F(DynamicUpdateSliceTest, DISABLED_ON_GPU(R3ContiguousLarger)) { + std::vector operand_shape({32, 128, 1024}); + RunR3Contiguous(operand_shape, /*index=*/7, /*size=*/1); +} + +void BM_DynamicSlice(int num_iters) { + tensorflow::testing::StopTiming(); + + se::Platform* platform = PlatformUtil::GetDefaultPlatform().ValueOrDie(); + auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie(); + StreamExecutorMemoryAllocator allocator(platform, executors); + LocalClient* client = + ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie(); + auto* transfer_manager = + TransferManager::GetForPlatform(platform).ValueOrDie(); + int device_ordinal = client->default_device_ordinal(); + + ComputationBuilder builder(client, "DynamicSlice"); + + // Create input as a constant: shape [1, 2, 3, 4] + auto input_literal = LiteralUtil::CreateR4( + {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, + {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); + auto input = builder.ConstantLiteral(*input_literal); + + // Create dynamic slice start indices as a parameter: shape [4] + auto start_indices_shape = ShapeUtil::MakeShape(S32, {4}); + auto start_indices = + builder.Parameter(0, start_indices_shape, "start_indices"); + // Add DynamicSlice op to the computatation. + builder.DynamicSlice(input, start_indices, {1, 1, 1, 1}); + auto computation = builder.Build().ConsumeValueOrDie(); + + // Initialize and transfer parameter buffer. + auto buffer = ScopedShapedBuffer::MakeScopedShapedBuffer(start_indices_shape, + &allocator, 0) + .ConsumeValueOrDie(); + + auto start_indices_literal = LiteralUtil::CreateR1({0, 1, 2, 3}); + ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( + executors[device_ordinal], *start_indices_literal, + buffer->mutable_buffer({}))); + + // Run some warm-up executions. + LocalExecuteOptions options; + options.set_allocator(&allocator); + const int kWarmups = 2; + for (int i = 0; i < kWarmups; ++i) { + auto result = client->ExecuteLocally(computation, {buffer.get()}, options); + ASSERT_TRUE(result.ok()); + } + + // Run benchmark. + tensorflow::testing::StartTiming(); + for (int i = 0; i < num_iters; ++i) { + auto result = client->ExecuteLocally(computation, {buffer.get()}, options); + ASSERT_TRUE(result.ok()); + } +} +BENCHMARK(BM_DynamicSlice); + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/floor_ceil_test.cc b/tensorflow/compiler/xla/tests/floor_ceil_test.cc new file mode 100644 index 0000000000..8e30063085 --- /dev/null +++ b/tensorflow/compiler/xla/tests/floor_ceil_test.cc @@ -0,0 +1,128 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +class FloorCeilTest : public ClientLibraryTestBase { + public: + enum Function { + kFloor, + kCeil, + }; + + // Runs a computation and comparison on expected vs f(input) + void TestR1F32(tensorflow::gtl::ArraySlice input, + tensorflow::gtl::ArraySlice expected, Function f) { + LOG(INFO) << "input: {" << tensorflow::str_util::Join(expected, ", ") + << "}"; + ComputationBuilder builder(client_, TestName()); + auto c = builder.ConstantR1(input); + if (f == kCeil) { + builder.Ceil(c); + } else { + ASSERT_EQ(kFloor, f); + builder.Floor(c); + } + ComputeAndCompareR1(&builder, expected, /*arguments=*/{}); + } + + void TestR0F32(float input, float expected, Function f) { + LOG(INFO) << "input: " << expected; + ComputationBuilder builder(client_, TestName()); + auto c = builder.ConstantR0(input); + if (f == kCeil) { + builder.Ceil(c); + } else { + ASSERT_EQ(kFloor, f); + builder.Floor(c); + } + ComputeAndCompareR0(&builder, expected, /*arguments=*/{}); + } + + const ErrorSpec error_spec_{0.0001}; + + float infinity_ = std::numeric_limits::infinity(); + float minus_infinity_ = -std::numeric_limits::infinity(); +}; + +// Interesting notes: +// * if you pass snan the CPU doesn't canonicalize it to qnan. 
+// * passing x86-based CPU's qnan to the GPU makes a different nan +// "7fc00000=nan=nan vs 7fffffff=nan=nan" + +XLA_TEST_F(FloorCeilTest, R1S0Floor) { TestR1F32({}, {}, kFloor); } + +TEST_F(FloorCeilTest, R1Floor) { + TestR1F32({0.0, -0.0, infinity_, minus_infinity_, 1.1, -0.1}, + {0.0, -0.0, infinity_, minus_infinity_, 1.0, -1.0}, kFloor); +} + +TEST_F(FloorCeilTest, R1Ceil) { + TestR1F32({0.0, -0.0, infinity_, minus_infinity_, 1.1, -0.1}, + {0.0, -0.0, infinity_, minus_infinity_, 2.0, -0.0}, kCeil); +} + +TEST_F(FloorCeilTest, R0Floor) { + TestR0F32(0.0, 0.0, kFloor); + TestR0F32(-0.0, -0.0, kFloor); + TestR0F32(infinity_, infinity_, kFloor); + TestR0F32(minus_infinity_, minus_infinity_, kFloor); + TestR0F32(1.1, 1.0, kFloor); + TestR0F32(-0.1, -1.0, kFloor); +} + +TEST_F(FloorCeilTest, R0Ceil) { + TestR0F32(0.0, 0.0, kCeil); + TestR0F32(-0.0, -0.0, kCeil); + TestR0F32(infinity_, infinity_, kCeil); + TestR0F32(minus_infinity_, minus_infinity_, kCeil); + TestR0F32(1.1, 2.0, kCeil); + TestR0F32(-0.1, -0.0, kCeil); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/fmax_test.cc b/tensorflow/compiler/xla/tests/fmax_test.cc new file mode 100644 index 0000000000..2835038c90 --- /dev/null +++ b/tensorflow/compiler/xla/tests/fmax_test.cc @@ -0,0 +1,61 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +class FmaxSimpleTest : public ClientLibraryTestBase {}; + +TEST_F(FmaxSimpleTest, FmaxTenValues) { + ComputationBuilder builder(client_, TestName()); + auto x = builder.ConstantR1( + {-0.0, 1.0, 2.0, -3.0, -4.0, 5.0, 6.0, -7.0, -8.0, 9.0}); + auto y = builder.ConstantR1( + {-0.0, -1.0, -2.0, 3.0, 4.0, -5.0, -6.0, 7.0, 8.0, -9.0}); + builder.Max(x, y); + + std::vector expected = {-0.0, 1.0, 2.0, 3.0, 4.0, + 5.0, 6.0, 7.0, 8.0, 9.0}; + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc new file mode 100644 index 0000000000..7bddbfa894 --- /dev/null +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -0,0 +1,589 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +using tensorflow::gtl::ArraySlice; + +namespace xla { +namespace { + +const int test_width = 2, test_height = 3; + +const float test_float_vals[3][test_width][test_height] = { + {{-1.0, -1.0, 1.0}, {-3.0, 0.0, -1.0}}, + {{-3.0, 2.0, 1.0}, {0.0, -3.0, 1.0}}, + {{-3.0, 0.0, -3.0}, {-1.0, -2.0, 1.0}}}; + +// Test whether fusion operations are emitted with no errors and compute +// accurate outputs. +class FusionTest : public HloTestBase { + protected: + template + void TestElementwise2D(HloOpcode opcode) { + Array2D operand_data[Arity]; + for (int i = 0; i < Arity; ++i) { + new (&operand_data[i]) Array2D(test_width, test_height); + } + Array2D answer_data(test_width, test_height); + for (int i = 0; i < test_width; ++i) { + for (int j = 0; j < test_height; ++j) { + float xs[Arity]; + for (int k = 0; k < Arity; ++k) { + xs[k] = test_float_vals[k][i][j]; + operand_data[k](i, j) = xs[k]; + } + answer_data(i, j) = ComputeElementwiseAnswer(opcode, xs); + } + } + + auto builder = HloComputation::Builder(TestName()); + auto hlo_module = MakeUnique(TestName()); + + auto prim_type = primitive_util::NativeToPrimitiveType(); + + HloInstruction* hlos[4]; + for (int i = 0; i < Arity; ++i) { + hlos[i + 1] = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2FromArray2D(operand_data[i]))); + } + auto answer_shape = + ShapeUtil::MakeShape(prim_type, {test_width, test_height}); + std::unique_ptr root_hlo; + switch (Arity) { + case 1: + root_hlo = HloInstruction::CreateUnary(answer_shape, opcode, hlos[1]); + break; + case 2: + root_hlo = HloInstruction::CreateBinary(answer_shape, opcode, hlos[1], + hlos[2]); + break; + case 3: + root_hlo = HloInstruction::CreateTernary(answer_shape, opcode, hlos[1], + hlos[2], hlos[3]); + break; + default: + LOG(FATAL) << "Bad arity: " << Arity; + } + hlos[0] = builder.AddInstruction(std::move(root_hlo)); + hlo_module->AddEntryComputation(builder.Build()) + ->CreateFusionInstruction( + ArraySlice(hlos, 0, Arity + 1), + HloInstruction::FusionKind::kLoop); + + auto expected = LiteralUtil::CreateR2FromArray2D(answer_data); + auto actual = ExecuteAndTransfer(std::move(hlo_module), {}); + if (primitive_util::IsFloatingPointType(prim_type)) { + LiteralTestUtil::ExpectNear(*expected, *actual, ErrorSpec(1e-4)); + } else { + LiteralTestUtil::ExpectEqual(*expected, *actual); + } + } + + private: + template + T 
ComputeElementwiseAnswer(HloOpcode opcode, ArraySlice xs); +}; + +template <> +float FusionTest::ComputeElementwiseAnswer(HloOpcode opcode, + ArraySlice xs) { + switch (opcode) { + case HloOpcode::kAdd: + return xs[0] + xs[1]; + case HloOpcode::kSubtract: + return xs[0] - xs[1]; + case HloOpcode::kMultiply: + return xs[0] * xs[1]; + case HloOpcode::kDivide: + return xs[0] / xs[1]; + case HloOpcode::kPower: + return powf(xs[0], xs[1]); + case HloOpcode::kMinimum: + return std::min(xs[0], xs[1]); + case HloOpcode::kMaximum: + return std::max(xs[0], xs[1]); + case HloOpcode::kClamp: + return std::min(xs[2], std::max(xs[1], xs[0])); + default: + LOG(FATAL) << "No elementwise opcode: " << opcode; + } +} + +template <> +uint8 FusionTest::ComputeElementwiseAnswer(HloOpcode opcode, + ArraySlice xs) { + switch (opcode) { + case HloOpcode::kEq: + return xs[0] == xs[1]; + case HloOpcode::kNe: + return xs[0] != xs[1]; + case HloOpcode::kGt: + return xs[0] > xs[1]; + case HloOpcode::kLt: + return xs[0] < xs[1]; + case HloOpcode::kGe: + return xs[0] >= xs[1]; + case HloOpcode::kLe: + return xs[0] <= xs[1]; + default: + LOG(FATAL) << "No comparatory opcode: " << opcode; + } +} + +XLA_TEST_F(FusionTest, Test) { + // test expression: + // slice(select({{T, F, T}, {F, T, F}}, + // concat(transpose({{1.0}, {2.0}, {3.0}} + + // {{-1.0}, {-1.0}, {-1.0}}), + // {{1.62, 2.72, 3.14}}) + + // (-{{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}}), + // {{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}})) = {{0.5}, {2.72}} + auto builder = HloComputation::Builder(TestName()); + auto hlo_module = MakeUnique(TestName()); + auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{1.0}, {2.0}, {3.0}}))); + auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{-1.0}, {-1.0}, {-1.0}}))); + auto add2 = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {3, 1}), HloOpcode::kAdd, const0, const1)); + auto reshape3 = builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {1, 3}), add2, {1, 0})); + auto const4 = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{1.62, 2.72, 3.14}}))); + auto concat5 = builder.AddInstruction(HloInstruction::CreateConcatenate( + ShapeUtil::MakeShape(F32, {2, 3}), {reshape3, const4}, 0)); + auto const6 = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}}))); + auto negate7 = builder.AddInstruction(HloInstruction::CreateUnary( + ShapeUtil::MakeShape(F32, {2, 3}), HloOpcode::kNegate, const6)); + auto add8 = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {2, 3}), HloOpcode::kAdd, concat5, negate7)); + auto const9 = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}}))); + auto const10 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR2( + {{true, false, true}, {false, true, false}}))); + auto select11 = builder.AddInstruction( + HloInstruction::CreateTernary(ShapeUtil::MakeShape(F32, {2, 3}), + HloOpcode::kSelect, const10, add8, const9)); + auto slice12 = builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {2, 1}), select11, {0, 1}, {2, 2})); + // CreateFusionInstruction needs the `instructions_to_fuse` argument in + // reverse topological order, so the first element in `instructions_to_fuse` + // must be the root. 
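  // Worked evaluation of the expression above, to make the expected literal
  // {{0.5}, {2.72}} easy to verify by hand (values rounded; nothing here is
  // part of the computation itself):
  //   add2     = {{1}, {2}, {3}} + {{-1}, {-1}, {-1}} = {{0}, {1}, {2}}
  //   reshape3 = transpose(add2) = {{0, 1, 2}}
  //   concat5  = {{0, 1, 2}, {1.62, 2.72, 3.14}}
  //   add8     = concat5 + (-{{1, 1, 1}, {0, 0, 0}})
  //            = {{-1, 0, 1}, {1.62, 2.72, 3.14}}
  //   select11 = select({{T, F, T}, {F, T, F}}, add8, 0.5s)
  //            = {{-1, 0.5, 1}, {0.5, 2.72, 0.5}}
  //   slice12  = rows [0, 2), column [1, 2) of select11 = {{0.5}, {2.72}}
  // Accordingly, the instruction list passed below starts with slice12 (the
  // root) and walks back to const0, i.e. each consumer precedes its producers.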
+ hlo_module->AddEntryComputation(builder.Build()) + ->CreateFusionInstruction( + {slice12, select11, const10, const9, add8, negate7, const6, concat5, + const4, reshape3, add2, const1, const0}, + HloInstruction::FusionKind::kLoop); + + LiteralTestUtil::ExpectNear(*LiteralUtil::CreateR2({{0.5}, {2.72}}), + *ExecuteAndTransfer(std::move(hlo_module), {}), + ErrorSpec(1e-4)); +} + +// Test whether we emit appropriate code for parameters of fusion instructions. +XLA_TEST_F(FusionTest, Parameter) { + // Build a computation and fuse part of it so the fusion instruction has an + // operand parameter. + auto builder = HloComputation::Builder(TestName()); + auto hlo_module = MakeUnique(TestName()); + auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{1.0, 2.0, 3.0}}))); + auto copy1 = builder.AddInstruction(HloInstruction::CreateUnary( + ShapeUtil::MakeShape(F32, {1, 3}), HloOpcode::kCopy, const0)); + auto const2 = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{-2.0, -2.0, -2.0}}))); + // add3 = copy1 + const2 = const0 + const2 = {1,2,3} + {-2,-2,-2} = {-1,0,+1} + auto add3 = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {1, 3}), HloOpcode::kAdd, copy1, const2)); + // CreateFusionInstruction needs `instructions_to_fuse` in reverse topological + // order. + hlo_module->AddEntryComputation(builder.Build()) + ->CreateFusionInstruction(/*instructions_to_fuse=*/{add3, const2}, + HloInstruction::FusionKind::kLoop); + + LiteralTestUtil::ExpectNear(*LiteralUtil::CreateR2({{-1.0, 0.0, 1.0}}), + *ExecuteAndTransfer(std::move(hlo_module), {}), + ErrorSpec(1e-4)); +} + +XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) { + auto builder = HloComputation::Builder(TestName()); + auto hlo_module = MakeUnique(TestName()); + auto const_vector = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1.0, 2.0, 3.0}))); + auto const_array = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}}))); + auto broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(const_array->shape(), const_vector, {1})); + // add2 = broadcast(const_vector) + const_array + // = broadcast({1,2,3}) + {{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}} + // = {{1, 2, 3}, {1, 2, 3}} + {{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}} + auto add2 = builder.AddInstruction( + HloInstruction::CreateBinary(ShapeUtil::MakeShape(F32, {2, 3}), + HloOpcode::kAdd, broadcast, const_array)); + hlo_module->AddEntryComputation(builder.Build()) + ->CreateFusionInstruction(/*instructions_to_fuse=*/{add2, broadcast}, + HloInstruction::FusionKind::kLoop); + + LiteralTestUtil::ExpectNear( + *LiteralUtil::CreateR2({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}), + *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)); +} + +XLA_TEST_F(FusionTest, ReshapeToScalar) { + auto builder = HloComputation::Builder(TestName()); + auto hlo_module = MakeUnique(TestName()); + auto single_element_array = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR2({{5}}))); + auto reshape = builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(S32, {}), single_element_array)); + hlo_module->AddEntryComputation(builder.Build()) + ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape}, + HloInstruction::FusionKind::kLoop); + LiteralTestUtil::ExpectEqual(*LiteralUtil::CreateR0(5), + *ExecuteAndTransfer(std::move(hlo_module), {})); +} + 
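// In BroadcastIntoBinaryOp above, the broadcast dimension list {1} maps the
// length-3 vector onto dimension 1 of the {2, 3} output, i.e. every row gets
// a copy of {1, 2, 3} before the add. A plain reference sketch of that
// expansion (hypothetical helper, shown only so the expected literal
// {{0, 0, -1}, {11, 22, 33}} is easy to reproduce by hand; assumes the
// Array2D header already included above):
Array2D<float> BroadcastAlongDim1ThenAdd(const std::vector<float>& vec,
                                         const Array2D<float>& array) {
  Array2D<float> result(array.height(), array.width());
  for (int64 i = 0; i < array.height(); ++i) {
    for (int64 j = 0; j < array.width(); ++j) {
      // The vector is indexed by the mapped dimension (dimension 1, i.e. j).
      result(i, j) = vec[j] + array(i, j);
    }
  }
  return result;
}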
+XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) { + auto builder = HloComputation::Builder(TestName()); + auto hlo_module = MakeUnique(TestName()); + auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}))); + auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(S32, {1, 2, 3}), const0)); + hlo_module->AddEntryComputation(builder.Build()) + ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, + HloInstruction::FusionKind::kLoop); + LiteralTestUtil::ExpectEqual( + *LiteralUtil::CreateR3({{{1, 2, 3}, {4, 5, 6}}}), + *ExecuteAndTransfer(std::move(hlo_module), {})); +} + +XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) { + auto builder = HloComputation::Builder(TestName()); + auto hlo_module = MakeUnique(TestName()); + auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR3({{{1, 2, 3}, {4, 5, 6}}}))); + auto reshape1 = builder.AddInstruction( + HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {3, 2}), const0)); + hlo_module->AddEntryComputation(builder.Build()) + ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, + HloInstruction::FusionKind::kLoop); + LiteralTestUtil::ExpectEqual( + *LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}), + *ExecuteAndTransfer(std::move(hlo_module), {})); +} + +XLA_TEST_F(FusionTest, Reshape_1by1by1_) { + auto builder = HloComputation::Builder(TestName()); + auto hlo_module = MakeUnique(TestName()); + auto const0 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR3({{{7}}}))); + auto reshape1 = builder.AddInstruction( + HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), const0)); + hlo_module->AddEntryComputation(builder.Build()) + ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, + HloInstruction::FusionKind::kLoop); + LiteralTestUtil::ExpectEqual(*LiteralUtil::CreateR0(7), + *ExecuteAndTransfer(std::move(hlo_module), {})); +} + +XLA_TEST_F(FusionTest, Reshape__1by1by1) { + auto builder = HloComputation::Builder(TestName()); + auto hlo_module = MakeUnique(TestName()); + auto const0 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(7))); + auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(S32, {1, 1, 1}), const0)); + hlo_module->AddEntryComputation(builder.Build()) + ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, + HloInstruction::FusionKind::kLoop); + LiteralTestUtil::ExpectEqual(*LiteralUtil::CreateR3({{{7}}}), + *ExecuteAndTransfer(std::move(hlo_module), {})); +} + +XLA_TEST_F(FusionTest, Reshape__) { + auto builder = HloComputation::Builder(TestName()); + auto hlo_module = MakeUnique(TestName()); + auto const0 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(7))); + auto reshape1 = builder.AddInstruction( + HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), const0)); + hlo_module->AddEntryComputation(builder.Build()) + ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, + HloInstruction::FusionKind::kLoop); + LiteralTestUtil::ExpectEqual(*LiteralUtil::CreateR0(7), + *ExecuteAndTransfer(std::move(hlo_module), {})); +} + +XLA_TEST_F(FusionTest, Reshape_3by3_3by3) { + auto builder = HloComputation::Builder(TestName()); + auto hlo_module = MakeUnique(TestName()); + auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); + auto reshape1 = 
builder.AddInstruction( + HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {3, 3}), const0)); + hlo_module->AddEntryComputation(builder.Build()) + ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, + HloInstruction::FusionKind::kLoop); + LiteralTestUtil::ExpectEqual( + *LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}), + *ExecuteAndTransfer(std::move(hlo_module), {})); +} + +XLA_TEST_F(FusionTest, Transpose_2by3) { + auto builder = HloComputation::Builder(TestName()); + auto hlo_module = MakeUnique(TestName()); + auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}}))); + auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(S32, {3, 2}), const0, {1, 0})); + hlo_module->AddEntryComputation(builder.Build()) + ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, + HloInstruction::FusionKind::kLoop); + LiteralTestUtil::ExpectEqual( + *LiteralUtil::CreateR2({{1, 4}, {2, 5}, {3, 6}}), + *ExecuteAndTransfer(std::move(hlo_module), {})); +} + +XLA_TEST_F(FusionTest, Transpose_3by3) { + auto builder = HloComputation::Builder(TestName()); + auto hlo_module = MakeUnique(TestName()); + auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); + auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(S32, {3, 3}), const0, {1, 0})); + hlo_module->AddEntryComputation(builder.Build()) + ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, + HloInstruction::FusionKind::kLoop); + LiteralTestUtil::ExpectEqual( + *LiteralUtil::CreateR2({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}), + *ExecuteAndTransfer(std::move(hlo_module), {})); +} + +XLA_TEST_F(FusionTest, Reverse) { + auto builder = HloComputation::Builder(TestName()); + auto hlo_module = MakeUnique(TestName()); + auto const0 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1({1, 2, 3}))); + auto reverse1 = builder.AddInstruction(HloInstruction::CreateReverse( + ShapeUtil::MakeShape(S32, {3}), const0, {0})); + hlo_module->AddEntryComputation(builder.Build()) + ->CreateFusionInstruction(/*instructions_to_fuse=*/{reverse1}, + HloInstruction::FusionKind::kLoop); + + LiteralTestUtil::ExpectEqual(*LiteralUtil::CreateR1({3, 2, 1}), + *ExecuteAndTransfer(std::move(hlo_module), {})); +} + +std::unique_ptr MakeReduceTestComputation() { + auto builder = HloComputation::Builder("add"); + auto lhs = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(S32, {}), "lhs")); + auto rhs = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(S32, {}), "rhs")); + builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(S32, {}), HloOpcode::kAdd, lhs, rhs)); + return builder.Build(); +} + +XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) { + auto hlo_module = MakeUnique(TestName()); + + auto builder = HloComputation::Builder(TestName()); + auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1, 2, 4, 8}))); + auto const1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + auto reduce2 = builder.AddInstruction(HloInstruction::CreateReduce( + ShapeUtil::MakeShape(S32, {}), const0, const1, {0}, + hlo_module->AddEmbeddedComputation(MakeReduceTestComputation()))); + 
hlo_module->AddEntryComputation(builder.Build()) + ->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce2}, + HloInstruction::FusionKind::kLoop); + + LiteralTestUtil::ExpectEqual(*LiteralUtil::CreateR0(15), + *ExecuteAndTransfer(std::move(hlo_module), {})); +} + +XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) { + auto hlo_module = MakeUnique(TestName()); + + auto builder = HloComputation::Builder(TestName()); + auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1, 2, 4, 8}))); + auto const1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + auto reduce2 = builder.AddInstruction(HloInstruction::CreateReduce( + ShapeUtil::MakeShape(S32, {}), const0, const1, {0}, + hlo_module->AddEmbeddedComputation(MakeReduceTestComputation()))); + auto negate3 = builder.AddInstruction(HloInstruction::CreateUnary( + ShapeUtil::MakeShape(S32, {1}), HloOpcode::kNegate, reduce2)); + hlo_module->AddEntryComputation(builder.Build()) + ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate3, reduce2}, + HloInstruction::FusionKind::kLoop); + + LiteralTestUtil::ExpectEqual(*LiteralUtil::CreateR1({-15}), + *ExecuteAndTransfer(std::move(hlo_module), {})); +} + +XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) { + auto builder = HloComputation::Builder(TestName()); + auto hlo_module = MakeUnique(TestName()); + auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{2, 3, 5}, {7, 11, 13}, {17, 19, 23}}))); + auto const1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); + Window window; + ASSERT_TRUE( + tensorflow::protobuf::TextFormat::ParseFromString("dimensions:{\n" + "size:2\n" + "stride:1\n" + "padding_low:0\n" + "padding_high:0\n" + "window_dilation:1\n" + "base_dilation:1\n" + "}\n" + "dimensions:{\n" + "size:2\n" + "stride:1\n" + "padding_low:0\n" + "padding_high:0\n" + "window_dilation:1\n" + "base_dilation:1\n" + "}\n", + &window)); + auto nested_builder = HloComputation::Builder("mul"); + { + auto x = nested_builder.AddInstruction( + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(S32, {}), "x")); + auto y = nested_builder.AddInstruction( + HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(S32, {}), "y")); + nested_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(S32, {}), HloOpcode::kMultiply, x, y)); + } + auto nested_computation = + hlo_module->AddEmbeddedComputation(nested_builder.Build()); + auto reduce_window2 = + builder.AddInstruction(HloInstruction::CreateReduceWindow( + ShapeUtil::MakeShape(S32, {2, 2}), const0, const1, window, + nested_computation)); + hlo_module->AddEntryComputation(builder.Build()) + ->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce_window2}, + HloInstruction::FusionKind::kLoop); + + LiteralTestUtil::ExpectEqual( + *LiteralUtil::CreateR2({{462, 2145}, {24871, 62491}}), + *ExecuteAndTransfer(std::move(hlo_module), {})); +} + +XLA_TEST_F(FusionTest, Add2D) { TestElementwise2D(HloOpcode::kAdd); } + +XLA_TEST_F(FusionTest, Subtract2D) { + TestElementwise2D(HloOpcode::kSubtract); +} + +XLA_TEST_F(FusionTest, Multiply2D) { + TestElementwise2D(HloOpcode::kMultiply); +} + +XLA_TEST_F(FusionTest, Divide2D) { + TestElementwise2D(HloOpcode::kDivide); +} + +XLA_TEST_F(FusionTest, Power2D) { + TestElementwise2D(HloOpcode::kPower); +} + +XLA_TEST_F(FusionTest, Minimum2D) { + TestElementwise2D(HloOpcode::kMinimum); +} + +XLA_TEST_F(FusionTest, Maximum2D) { + 
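  // The expected literals in the Reduce and ReduceWindow cases above can be
  // checked by hand: the add-reduction of {1, 2, 4, 8} with init 0 is 15, and
  // each ReduceWindow output is the product of one 2x2 window of
  // {{2, 3, 5}, {7, 11, 13}, {17, 19, 23}} with init 1 (window size 2,
  // stride 1, no padding). A compile-time restatement of that arithmetic,
  // illustrative only and not part of the tests:
  static_assert(1 + 2 + 4 + 8 == 15, "Reduce expectation");
  static_assert(2 * 3 * 7 * 11 == 462, "ReduceWindow output (0, 0)");
  static_assert(3 * 5 * 11 * 13 == 2145, "ReduceWindow output (0, 1)");
  static_assert(7 * 11 * 17 * 19 == 24871, "ReduceWindow output (1, 0)");
  static_assert(11 * 13 * 19 * 23 == 62491, "ReduceWindow output (1, 1)");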
TestElementwise2D(HloOpcode::kMaximum); +} + +XLA_TEST_F(FusionTest, Equal2D) { TestElementwise2D(HloOpcode::kEq); } + +XLA_TEST_F(FusionTest, Inequal2D) { + TestElementwise2D(HloOpcode::kNe); +} + +XLA_TEST_F(FusionTest, Greater2D) { + TestElementwise2D(HloOpcode::kGt); +} + +XLA_TEST_F(FusionTest, Lesser2D) { + TestElementwise2D(HloOpcode::kLt); +} + +XLA_TEST_F(FusionTest, GreaterOrEqual2D) { + TestElementwise2D(HloOpcode::kGe); +} + +XLA_TEST_F(FusionTest, LesserOrEqual2D) { + TestElementwise2D(HloOpcode::kLe); +} + +XLA_TEST_F(FusionTest, Clamp2D) { + TestElementwise2D(HloOpcode::kClamp); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc new file mode 100644 index 0000000000..872188de81 --- /dev/null +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -0,0 +1,204 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" + +#include +#include +#include + +#define EIGEN_USE_THREADS + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/legacy_flags/hlo_test_base_flags.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/backend.h" +#include "tensorflow/compiler/xla/service/computation_layout.h" +#include "tensorflow/compiler/xla/service/executable.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_execution_profile.h" +#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/transfer_manager.h" +#include "tensorflow/compiler/xla/shape_layout.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/common_runtime/eigen_thread_pool.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace se = ::perftools::gputools; + +namespace xla { + +// Define this in .cc file to avoid having to include eigen or forward declare +// these types in the header. 
+struct HloTestBase::EigenThreadPoolWrapper { + std::unique_ptr pool; + std::unique_ptr device; +}; + +HloTestBase::HloTestBase() + : backend_(Backend::CreateDefaultBackend().ConsumeValueOrDie()) { + test_hlo_dumper_ = [](const HloModule& module, const string& label) { + legacy_flags::HloTestBaseFlags* flags = legacy_flags::GetHloTestBaseFlags(); + if (flags->xla_hlo_test_generate_hlo_graph) { + const bool show_addresses = true; + const bool show_layouts = true; + hlo_graph_dumper::DumpGraph(*module.entry_computation(), label, + show_addresses, show_layouts); + } + }; + VLOG(1) << "executing on platform " << backend_->platform()->Name(); +} + +HloTestBase::~HloTestBase() { + // Deallocate all the memory allocated during the tests. + for (auto& allocation : allocations_) { + backend_->default_stream_executor()->Deallocate(&allocation); + } +} + +StatusOr HloTestBase::Execute( + std::unique_ptr module, + tensorflow::gtl::ArraySlice + arguments, + Shape* result_shape) { + auto module_config = MakeUnique( + MakeProgramShape(module->entry_computation())); + return Execute(std::move(module), std::move(module_config), arguments, + result_shape); +} + +StatusOr HloTestBase::Execute( + std::unique_ptr hlo_module, + std::unique_ptr module_config, + tensorflow::gtl::ArraySlice arguments, + Shape* result_shape) { + VLOG(3) << "module_config layout " + << LayoutUtil::HumanString(module_config->entry_computation_layout() + .result_layout() + .layout()); + TF_ASSIGN_OR_RETURN( + std::unique_ptr executable, + backend_->compiler()->Compile(std::move(hlo_module), + std::move(module_config), test_hlo_dumper_, + backend_->default_stream_executor())); + + se::Stream stream(backend_->default_stream_executor()); + stream.Init(); + + ExecutableRunOptions run_options; + run_options.set_stream(&stream); + run_options.set_allocator(backend_->memory_allocator()); + run_options.set_inter_op_thread_pool(backend_->inter_op_thread_pool()); + run_options.set_intra_op_thread_pool( + backend_->eigen_intra_op_thread_pool_device()); + + HloExecutionProfile hlo_execution_profile; + TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase result, + executable->ExecuteOnStream(&run_options, arguments, + &hlo_execution_profile)); + TF_RET_CHECK(stream.BlockHostUntilDone()); + + allocations_.push_back(result); + + *result_shape = executable->result_shape(); + + if (ShapeUtil::IsTuple(*result_shape)) { + // We must record element buffers of tuples as well to avoid leaks. + DCHECK(!ShapeUtil::IsNestedTuple(*result_shape)); + TF_ASSIGN_OR_RETURN( + std::vector element_buffers, + backend_->transfer_manager()->ShallowCopyTupleFromDevice( + backend_->default_stream_executor(), result, *result_shape)); + + // A tuple may contain the same buffer in more than one element. Keep track + // of the buffers already added to avoid duplicates in allocations_. + std::set added_opaques; + for (auto element_buffer : element_buffers) { + if (added_opaques.count(element_buffer.opaque()) == 0) { + added_opaques.insert(element_buffer.opaque()); + allocations_.push_back(element_buffer); + } + } + } + + return result; +} + +se::DeviceMemoryBase HloTestBase::TransferToDevice(const Literal& literal) { + // Allocate memory on the device using the stream executor. 
+ int64 allocation_size = + backend_->transfer_manager()->GetByteSizeRequirement(literal.shape()); + se::DeviceMemoryBase allocation = + backend_->default_stream_executor()->AllocateArray( + allocation_size); + allocations_.push_back(allocation); + + TF_CHECK_OK(backend_->transfer_manager()->TransferLiteralToDevice( + backend_->default_stream_executor(), literal, &allocation)); + + return allocation; +} + +std::unique_ptr HloTestBase::TransferFromDevice( + const Shape& shape, se::DeviceMemoryBase device_base) { + auto literal = MakeUnique(); + TF_CHECK_OK(backend_->transfer_manager()->TransferLiteralFromDevice( + backend_->default_stream_executor(), device_base, shape, shape, + literal.get())); + return literal; +} + +std::unique_ptr HloTestBase::ExecuteAndTransfer( + std::unique_ptr module, + tensorflow::gtl::ArraySlice arguments) { + Shape result_shape; + se::DeviceMemoryBase device_base = + Execute(std::move(module), arguments, &result_shape).ValueOrDie(); + return TransferFromDevice(result_shape, device_base); +} + +std::unique_ptr HloTestBase::ExecuteAndTransfer( + std::unique_ptr module, + std::unique_ptr module_config, + tensorflow::gtl::ArraySlice arguments) { + Shape result_shape; + se::DeviceMemoryBase device_base = + Execute(std::move(module), std::move(module_config), arguments, + &result_shape) + .ValueOrDie(); + return TransferFromDevice(result_shape, device_base); +} + +ProgramShape HloTestBase::MakeProgramShape(HloComputation* computation) { + ProgramShape program_shape; + for (int64 i = 0; i < computation->num_parameters(); ++i) { + *program_shape.add_parameters() = + computation->parameter_instruction(i)->shape(); + } + *program_shape.mutable_result() = computation->root_instruction()->shape(); + return program_shape; +} + +string HloTestBase::TestName() const { + return ::testing::UnitTest::GetInstance()->current_test_info()->name(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h new file mode 100644 index 0000000000..fa88c76899 --- /dev/null +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -0,0 +1,107 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_TESTS_HLO_TEST_BASE_H_ +#define TENSORFLOW_COMPILER_XLA_TESTS_HLO_TEST_BASE_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/service/backend.h" +#include "tensorflow/compiler/xla/service/compiler.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { + +// A base class for tests which build and run HLO code. This is a lower level of +// abstraction than using the client interface and enables, for one, explicitly +// building a graph of HLO instructions to run. +class HloTestBase : public ::testing::Test { + protected: + struct EigenThreadPoolWrapper; + HloTestBase(); + + ~HloTestBase() override; + + // Executes the given module and returns a global data handle. + StatusOr Execute( + std::unique_ptr module, + tensorflow::gtl::ArraySlice + arguments, + Shape* result_shape); + + // Variation of Execute which takes a custom module_config instead of creating + // a default one. + StatusOr Execute( + std::unique_ptr module, + std::unique_ptr module_config, + tensorflow::gtl::ArraySlice + arguments, + Shape* result_shape); + + // Transfers the given literal to the device and returns the data handle. + perftools::gputools::DeviceMemoryBase TransferToDevice( + const Literal& literal); + + // Transfers the array refered to by the given handle from the device and + // returns as a Literal. + std::unique_ptr TransferFromDevice( + const Shape& shape, perftools::gputools::DeviceMemoryBase device_base); + + // Executes the given module and return the result as a Literal. + std::unique_ptr ExecuteAndTransfer( + std::unique_ptr module, + tensorflow::gtl::ArraySlice + arguments); + + // Variation of ExecuteAndTransfer which takes a custom module_config instead + // of creating a default one. + std::unique_ptr ExecuteAndTransfer( + std::unique_ptr module, + std::unique_ptr module_config, + tensorflow::gtl::ArraySlice + arguments); + + // Utility function which creates a ProgramShape for a given computation. + ProgramShape MakeProgramShape(HloComputation* computation); + + string TestName() const; + + std::unique_ptr backend_; + + Compiler::HloDumper test_hlo_dumper_; + + // This vector contains handles of all the device memory allocations performed + // by the test. These are deallocated on destruction of the test object. + std::vector allocations_; + + ErrorSpec error_spec_{0.0001}; + + std::unique_ptr thread_pool_wrapper_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_TESTS_HLO_TEST_BASE_H_ diff --git a/tensorflow/compiler/xla/tests/inprocess_service_test.cc b/tensorflow/compiler/xla/tests/inprocess_service_test.cc new file mode 100644 index 0000000000..9909f041de --- /dev/null +++ b/tensorflow/compiler/xla/tests/inprocess_service_test.cc @@ -0,0 +1,204 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +// Tests which exercise the "InProcess" methods of xla::Client. The +// "InProcess" methods require that the client and server share the same +// process. +class InProcessServiceTest : public ClientLibraryTestBase { + protected: + std::unique_ptr ExecuteR2F32Constant( + std::initializer_list> values, + tensorflow::gtl::ArraySlice minor_to_major) { + ComputationBuilder builder(client_, TestName()); + builder.ConstantR2(values); + auto computation = builder.Build().ConsumeValueOrDie(); + CHECK_EQ(2, minor_to_major.size()); + Shape shape_with_layout = ShapeUtil::MakeShapeWithLayout( + F32, + /*dimensions=*/{static_cast(values.size()), + static_cast(values.begin()->size())}, + minor_to_major); + return client_ + ->Execute(computation, {}, &shape_with_layout, + /*execution_profile=*/nullptr) + .ConsumeValueOrDie(); + } + + ErrorSpec error_spec_{0.0001}; +}; + +XLA_TEST_F(InProcessServiceTest, TransferFromServer) { + ComputationBuilder builder(client_, TestName()); + builder.ConstantR1({1, 42, 5}); + auto computation = builder.Build().ConsumeValueOrDie(); + + auto handle = client_->Execute(computation, {}).ConsumeValueOrDie(); + + std::vector result(3, 0); + ASSERT_IS_OK(client_->TransferInProcess(*handle, result.data())); + EXPECT_MATCH(result, testing::VectorMatcher({1, 42, 5})); +} + +XLA_TEST_F(InProcessServiceTest, TransferToServer) { + std::vector input{1.0f, 2.0f, -42.0f}; + Shape shape = ShapeUtil::MakeShape(F32, {3}); + auto data_handle = client_->TransferToServerInProcess(shape, input.data()) + .ConsumeValueOrDie(); + + ComputationBuilder builder(client_, TestName()); + auto param = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3}), "param"); + builder.Add(param, param); + + ComputeAndCompareR1(&builder, {2.0f, 4.0f, -84.0f}, + {data_handle.get()}, error_spec_); +} + +// TODO(b/28506710): This test case seems not to test inprocess +// methods. 
+TEST_F(InProcessServiceTest, GetShape) { + ComputationBuilder builder(client_, TestName()); + builder.ConstantR1({1, 42, 5}); + auto computation = builder.Build().ConsumeValueOrDie(); + + auto handle = client_->Execute(computation, {}).ConsumeValueOrDie(); + + Shape shape = client_->GetShape(*handle).ConsumeValueOrDie(); + ASSERT_EQ(S32, shape.element_type()); + ASSERT_EQ(1, ShapeUtil::Rank(shape)); + ASSERT_EQ(3, shape.dimensions(0)); +} + +XLA_TEST_F(InProcessServiceTest, GetShapeOfClientSuppliedArrayRowMajor) { + std::vector input{1.0f, 2.0f, 3.0f, 4.0f}; + Shape shape = ShapeUtil::MakeShape(F32, {2, 2}); + shape.clear_layout(); + *shape.mutable_layout() = LayoutUtil::MakeLayout({1, 0}); + auto handle = client_->TransferToServerInProcess(shape, input.data()) + .ConsumeValueOrDie(); + + Shape shape_returned = client_->GetShape(*handle).ConsumeValueOrDie(); + ASSERT_TRUE(ShapeUtil::Equal(shape, shape_returned)); +} + +XLA_TEST_F(InProcessServiceTest, GetShapeOfClientSuppliedArrayColMajor) { + std::vector input{1.0f, 2.0f, 3.0f, 4.0f}; + Shape shape = ShapeUtil::MakeShape(F32, {2, 2}); + shape.clear_layout(); + *shape.mutable_layout() = LayoutUtil::MakeLayout({0, 1}); + auto handle = client_->TransferToServerInProcess(shape, input.data()) + .ConsumeValueOrDie(); + + Shape shape_returned = client_->GetShape(*handle).ConsumeValueOrDie(); + ASSERT_TRUE(ShapeUtil::Equal(shape, shape_returned)); +} + +TEST_F(InProcessServiceTest, TransferToServerNoLayout) { + std::vector input{1.0f, 2.0f, -42.0f}; + Shape shape = ShapeUtil::MakeShape(F32, {3}); + shape.clear_layout(); + auto transfer_status = + client_->TransferToServerInProcess(shape, input.data()); + ASSERT_EQ(transfer_status.status().code(), + tensorflow::error::INVALID_ARGUMENT); +} + +XLA_TEST_F(InProcessServiceTest, ExecuteRowMajor) { + auto handle = + ExecuteR2F32Constant({{1.0, 2.0}, {3.0, 4.0}}, /*minor_to_major=*/{1, 0}); + + std::vector result(4, 0.0); + Shape shape; + ASSERT_IS_OK(client_->TransferInProcess(*handle, result.data())); + + EXPECT_MATCH(result, testing::VectorMatcher({1.0, 2.0, 3.0, 4.0})); +} + +XLA_TEST_F(InProcessServiceTest, ExecuteColumnMajor) { + auto handle = + ExecuteR2F32Constant({{1.0, 2.0}, {3.0, 4.0}}, /*minor_to_major=*/{0, 1}); + + std::vector result(4, 0); + Shape shape; + ASSERT_IS_OK(client_->TransferInProcess(*handle, result.data())); + + EXPECT_MATCH(result, testing::VectorMatcher({1.0, 3.0, 2.0, 4.0})); +} + +XLA_TEST_F(InProcessServiceTest, ExecuteAndReuseDifferentLayouts) { + // Create arrays on the server which have different layouts. Verify the + // computation still produces the correct results. 
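  // Layout note for the two preceding tests: minor_to_major lists dimensions
  // from fastest- to slowest-varying in memory. For the 2x2 constant
  // {{1, 2}, {3, 4}}, minor_to_major {1, 0} (dimension 1 minor, i.e.
  // row-major) stores 1, 2, 3, 4, while {0, 1} (dimension 0 minor, i.e.
  // column-major) stores 1, 3, 2, 4, which is exactly what the two
  // VectorMatcher expectations above encode. Equivalently, element (i0, i1)
  // sits at linear index i0 * 2 + i1 under {1, 0} and at i1 * 2 + i0 under
  // {0, 1}.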
+ auto handle_rowmaj = + ExecuteR2F32Constant({{1.0, 2.0}, {3.0, 4.0}}, /*minor_to_major=*/{1, 0}); + + auto handle_colmaj = ExecuteR2F32Constant({{10.0, 20.0}, {30.0, 40.0}}, + /*minor_to_major=*/{0, 1}); + + ComputationBuilder builder(client_, TestName()); + auto param0 = + builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"); + auto param1 = + builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2}), "param1"); + builder.Add(param0, param1); + + Array2D expected({{11.0, 22.0}, {33.0, 44.0}}); + ComputeAndCompareR2(&builder, expected, + {handle_rowmaj.get(), handle_colmaj.get()}, + error_spec_); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc new file mode 100644 index 0000000000..f7bbc0f38b --- /dev/null +++ b/tensorflow/compiler/xla/tests/literal_test_util.cc @@ -0,0 +1,566 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/tests/literal_test_util.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/index_util.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/casts.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +/* static */ void LiteralTestUtil::AssertEqualShapes(const Shape& expected, + const Shape& actual) { + ASSERT_EQ(ShapeUtil::Rank(expected), ShapeUtil::Rank(actual)); + ASSERT_EQ(expected.element_type(), actual.element_type()) + << PrimitiveType_Name(expected.element_type()) << " vs " + << PrimitiveType_Name(actual.element_type()); + ASSERT_EQ(expected.dimensions_size(), actual.dimensions_size()); + for (int i = 0; i < expected.dimensions_size(); ++i) { + ASSERT_EQ(expected.dimensions(i), actual.dimensions(i)) + << "mismatch in dimension #" << i + << " expected: " << ShapeUtil::HumanString(expected) + << " actual: " << ShapeUtil::HumanString(actual); + } + ASSERT_EQ(expected.tuple_shapes_size(), actual.tuple_shapes_size()); + for (int i = 0; i < expected.tuple_shapes_size(); ++i) { + AssertEqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i)); + } +} + +/* static */ void LiteralTestUtil::AssertEqualShapesAndLayouts( + const Shape& expected, const Shape& actual) { + ASSERT_EQ(expected.ShortDebugString(), actual.ShortDebugString()); +} + +namespace { + +string Hostname() { + char hostname[1024]; + gethostname(hostname, sizeof hostname); + hostname[sizeof hostname - 1] = 0; + return string(hostname); +} + +// Helper function for comparing a floating point type, FloatT, bitwise equal +// between the left-hand-side and right-hand-side, by bit-casting to UnsignedT +// -- on miscompare, a nice error message is given in the AssertionFailure. +template +testing::AssertionResult CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) { + auto ulhs = tensorflow::bit_cast(lhs); + auto urhs = tensorflow::bit_cast(rhs); + if (ulhs != urhs) { + return testing::AssertionFailure() << tensorflow::strings::Printf( + "floating values are not bitwise-equal; and equality testing " + "was requested: %s=%g=%a vs %s=%g=%a", + tensorflow::strings::StrCat(tensorflow::strings::Hex(ulhs)) + .c_str(), + lhs, lhs, + tensorflow::strings::StrCat(tensorflow::strings::Hex(urhs)) + .c_str(), + rhs, rhs); + } + return testing::AssertionSuccess(); +} + +// Templated comparator that specializes for float equality comparison with the +// bitwise helper above (this is the un-specialized fallback, to just use the +// default gunit implementation). 
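As a minimal illustration of the bit-cast trick above (using std::memcpy instead of TensorFlow's bit_cast so it stands alone), bitwise equality distinguishes values that operator== treats as equal, such as -0.0f and +0.0f, and gives a NaN bit pattern a stable identity:

#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

// Bitwise equality for 32-bit floats via memcpy, the portable way to
// reinterpret a float's bits without violating aliasing rules.
bool BitwiseEqual(float lhs, float rhs) {
  uint32_t ulhs, urhs;
  std::memcpy(&ulhs, &lhs, sizeof(ulhs));
  std::memcpy(&urhs, &rhs, sizeof(urhs));
  return ulhs == urhs;
}

int main() {
  assert(0.0f == -0.0f);               // Equal under operator==...
  assert(!BitwiseEqual(0.0f, -0.0f));  // ...but not bitwise-equal.
  float nan = std::nanf("");
  assert(nan != nan);                  // NaN never compares equal to itself...
  assert(BitwiseEqual(nan, nan));      // ...yet its bit pattern does.
  return 0;
}

The unspecialized CompareEqual fallback that follows simply defers to operator== for every other element type.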
+template <typename NativeT>
+testing::AssertionResult CompareEqual(NativeT lhs, NativeT rhs) {
+  if (lhs == rhs) {
+    return testing::AssertionSuccess();
+  }
+  ::testing::Message msg;
+  msg << "Expected equality of these values:";
+  msg << "\n " << lhs;
+  msg << "\n " << rhs;
+
+  return testing::AssertionFailure() << msg;
+}
+
+// Specializations for floating types that do bitwise comparisons when equality
+// comparison is requested.
+template <>
+testing::AssertionResult CompareEqual<float>(float lhs, float rhs) {
+  return CompareFloatsBitwiseEqual<float, uint32>(lhs, rhs);
+}
+template <>
+testing::AssertionResult CompareEqual<double>(double lhs, double rhs) {
+  return CompareFloatsBitwiseEqual<double, uint64>(lhs, rhs);
+}
+
+// A recursive function which iterates through every index of expected and
+// actual literal and compares their values elementwise. Returns true if all
+// elements are equal.
+template <typename NativeT>
+bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual,
+                         tensorflow::gtl::MutableArraySlice<int64> multi_index,
+                         int64 dimension) {
+  if (dimension == expected.shape().dimensions_size()) {
+    NativeT expected_value = LiteralUtil::Get<NativeT>(expected, multi_index);
+    NativeT actual_value = LiteralUtil::Get<NativeT>(actual, multi_index);
+    testing::AssertionResult result =
+        CompareEqual<NativeT>(expected_value, actual_value);
+    return result;  // Defines implicit coercion to bool.
+  }
+
+  bool all_match = true;
+  for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) {
+    multi_index[dimension] = i;
+    all_match = all_match && ExpectLiteralsEqual<NativeT>(
+                                 expected, actual, multi_index, dimension + 1);
+  }
+  return all_match;
+}
+
+}  // namespace
+
+/* static */ void LiteralTestUtil::ExpectEqual(const Literal& expected,
+                                               const Literal& actual) {
+  EXPECT_TRUE(Equal(expected, actual)) << "expected:\n"
+                                       << LiteralUtil::ToString(expected)
+                                       << "\n\tvs actual:\n"
+                                       << LiteralUtil::ToString(actual);
+}
+
+/* static */ void LiteralTestUtil::ExpectNotEqual(const Literal& expected,
+                                                  const Literal& actual) {
+  EXPECT_FALSE(Equal(expected, actual));
+}
+
+/* static */ testing::AssertionResult LiteralTestUtil::Equal(
+    const Literal& expected, const Literal& actual) {
+  VLOG(1) << "expected: " << LiteralUtil::ToString(expected);
+  VLOG(1) << "actual: " << LiteralUtil::ToString(actual);
+
+  AssertEqualShapes(expected.shape(), actual.shape());
+  std::vector<int64> multi_index(expected.shape().dimensions_size(), 0);
+  bool match = false;
+  switch (expected.shape().element_type()) {
+    case PRED:
+      match = ExpectLiteralsEqual<bool>(expected, actual, &multi_index, 0);
+      break;
+    case U8:
+      match = ExpectLiteralsEqual<uint8>(expected, actual, &multi_index, 0);
+      break;
+    case S32:
+      match = ExpectLiteralsEqual<int32>(expected, actual, &multi_index, 0);
+      break;
+    case S64:
+      match = ExpectLiteralsEqual<int64>(expected, actual, &multi_index, 0);
+      break;
+    case U32:
+      match = ExpectLiteralsEqual<uint32>(expected, actual, &multi_index, 0);
+      break;
+    case U64:
+      match = ExpectLiteralsEqual<uint64>(expected, actual, &multi_index, 0);
+      break;
+    case F32:
+      match = ExpectLiteralsEqual<float>(expected, actual, &multi_index, 0);
+      break;
+    case F64:
+      match = ExpectLiteralsEqual<double>(expected, actual, &multi_index, 0);
+      break;
+    case TUPLE: {
+      bool tuple_match = true;
+      for (int i = 0; i < actual.tuple_literals_size(); ++i) {
+        auto result =
+            Equal(expected.tuple_literals(i), actual.tuple_literals(i));
+        tuple_match = tuple_match ?
!!result : false; + } + match = tuple_match; + break; + } + default: + LOG(FATAL) + << "Unsupported primitive type in LiteralTestUtil::ExpectEqual: " + << PrimitiveType_Name(expected.shape().element_type()); + } + testing::AssertionResult result = testing::AssertionSuccess(); + if (!match) { + result = testing::AssertionFailure() + << "expected: " << LiteralUtil::ToString(expected) + << "\nactual: " << LiteralUtil::ToString(actual); + VLOG(1) << result.message(); + } + return result; +} + +/* static */ void LiteralTestUtil::ExpectEqualTuple(const Literal& expected, + const Literal& actual) { + VLOG(1) << "expected: " << LiteralUtil::ToString(expected); + VLOG(1) << "actual: " << LiteralUtil::ToString(actual); + + ASSERT_TRUE(ShapeUtil::IsTuple(expected.shape())); + ASSERT_TRUE(ShapeUtil::IsTuple(actual.shape())); + AssertEqualShapes(expected.shape(), actual.shape()); + for (uint64 i = 0; i < expected.tuple_literals_size(); ++i) { + const auto& expected_element = expected.tuple_literals(i); + const auto& actual_element = actual.tuple_literals(i); + if (ShapeUtil::IsTuple(expected_element.shape())) { + ExpectEqualTuple(expected_element, actual_element); + } else { + ExpectEqual(expected_element, actual_element); + } + } +} + +namespace { + +// Helper class for comparing floating-point literals within an error bound. +class NearComparator { + public: + explicit NearComparator(ErrorSpec error) : error_(error) {} + + // Compares the two literals elementwise. EXPECTs each pair of elements to be + // within the error bound. Emits useful log messages and dumps literals to + // temporary files on failure. Returns true if literals match. + bool ExpectNear(const Literal& expected, const Literal& actual) { + VLOG(1) << "expected: " << LiteralUtil::ToString(expected); + VLOG(1) << "actual: " << LiteralUtil::ToString(actual); + + LiteralTestUtil::AssertEqualShapes(expected.shape(), actual.shape()); + + // Set up members used during the comparison. + num_miscompares_ = 0; + abs_diff_sum_ = 0.0; + abs_expected_sum_ = 0.0; + abs_diff_miscompare_sum_ = 0.0; + abs_expected_miscompare_sum_ = 0.0; + max_rel_err_ = 0.0; + max_abs_err_ = 0.0; + *miscompares_.mutable_shape() = + ShapeUtil::ChangeElementType(actual.shape(), PRED); + miscompares_.mutable_preds()->Resize( + ShapeUtil::ElementsIn(miscompares_.shape()), false); + multi_index_.resize(expected.shape().dimensions_size(), 0); + + switch (expected.shape().element_type()) { + case F32: + ExpectLiteralsNear(expected, actual, 0); + break; + case F64: + ExpectLiteralsNear(expected, actual, 0); + break; + default: + LOG(FATAL) << "Unsupported primitive type in near comparator: " + << PrimitiveType_Name(expected.shape().element_type()) + << ". 
Must be floating-point type.";
+    }
+
+    if (num_miscompares_ > 0) {
+      if (!VLOG_IS_ON(1)) {
+        LOG(INFO) << "expected: " << ShapeUtil::HumanString(expected.shape())
+                  << " " << LiteralUtil::ToString(expected);
+        LOG(INFO) << "actual: " << ShapeUtil::HumanString(actual.shape())
+                  << " " << LiteralUtil::ToString(actual);
+      }
+      EXPECT_TRUE(num_miscompares_ == 0)
+          << "\nmax relative mismatch at index "
+          << LiteralTestUtil::MultiIndexAsString(max_rel_multi_index_)
+          << "\nmaximum relative error " << max_rel_err_
+          << "\nmax absolute mismatch at index "
+          << LiteralTestUtil::MultiIndexAsString(max_abs_multi_index_)
+          << "\nmaximum absolute error " << max_abs_err_
+          << "\nfirst mismatch at index "
+          << LiteralTestUtil::MultiIndexAsString(first_multi_index_)
+          << "\nlast mismatch at index "
+          << LiteralTestUtil::MultiIndexAsString(last_multi_index_)
+          << "\ntotal absolute error " << abs_diff_sum_
+          << "\ntotal absolute error of miscompares "
+          << abs_diff_miscompare_sum_ << "\ntotal relative error "
+          << (abs_diff_sum_ / abs_expected_sum_)
+          << "\ntotal relative error of miscompares "
+          << (abs_diff_miscompare_sum_ / abs_expected_miscompare_sum_)
+          << "\nfailure count " << num_miscompares_;
+
+      WriteLiteralToTempFile(expected, "expected");
+      WriteLiteralToTempFile(actual, "actual");
+      WriteLiteralToTempFile(miscompares_, "miscompares");
+    }
+    return num_miscompares_ == 0;
+  }
+
+ private:
+  // EXPECTs that the two given scalar values are within the error bound. Keeps
+  // track of how many mismatches have occurred to keep the size of the output
+  // manageable.
+  template <typename NativeT>
+  bool ExpectValuesNear(NativeT expected, NativeT actual) {
+    if (expected == actual) {
+      return true;
+    }
+
+    float abs_diff = std::abs(actual - expected);
+    float rel_err = abs_diff / std::abs(expected);
+    abs_diff_sum_ += abs_diff;
+    abs_expected_sum_ += std::abs(expected);
+    if (rel_err > max_rel_err_) {
+      max_rel_err_ = rel_err;
+      max_rel_multi_index_ = multi_index_;
+    }
+    if (abs_diff > max_abs_err_) {
+      max_abs_err_ = abs_diff;
+      max_abs_multi_index_ = multi_index_;
+    }
+    VLOG(10) << tensorflow::strings::Printf(
+        "index %s abs_diff %f rel_err %f",
+        LiteralTestUtil::MultiIndexAsString(multi_index_).c_str(), abs_diff,
+        rel_err);
+    bool nan_mismatch = std::isnan(actual) != std::isnan(expected);
+    bool mismatch =
+        (nan_mismatch || (abs_diff >= error_.abs && rel_err >= error_.rel));
+    if (mismatch) {
+      abs_diff_miscompare_sum_ += abs_diff;
+      abs_expected_miscompare_sum_ += std::abs(expected);
+      const int64 kMaxFailures = 2;
+      if (num_miscompares_ < kMaxFailures) {
+        EXPECT_NEAR(expected, actual, error_.abs)
+            << "mismatch at index "
+            << LiteralTestUtil::MultiIndexAsString(multi_index_) << " abs diff "
+            << abs_diff << " rel err " << rel_err << " failure #"
+            << num_miscompares_;
+      } else if (num_miscompares_ == kMaxFailures) {
+        LOG(ERROR)
+            << "reached max 'loud' failure count; silently proceeding...";
+      }
+      if (num_miscompares_ == 0) {
+        first_multi_index_ = multi_index_;
+      }
+      num_miscompares_++;
+      last_multi_index_ = multi_index_;
+    }
+    return !mismatch;
+  }
+
+  // Recursive function which compares the two given literals elementwise.
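The mismatch rule implemented by ExpectValuesNear can be summarized as: a pair of values miscompares only if their NaN-ness differs or the difference violates both the absolute and the relative bound of the ErrorSpec. Below is a standalone restatement of just that predicate, with invented names and none of the statistics bookkeeping.

#include <cmath>
#include <cstdio>

// Returns true if `actual` is acceptably close to `expected` under an
// absolute bound `abs_bound` and a relative bound `rel_bound`. A pair only
// miscompares when it violates *both* bounds, or when exactly one side is NaN.
bool ValuesNear(float expected, float actual, float abs_bound,
                float rel_bound) {
  if (expected == actual) return true;
  float abs_diff = std::fabs(actual - expected);
  float rel_err = abs_diff / std::fabs(expected);
  bool nan_mismatch = std::isnan(actual) != std::isnan(expected);
  bool mismatch =
      nan_mismatch || (abs_diff >= abs_bound && rel_err >= rel_bound);
  return !mismatch;
}

int main() {
  // Absolute bound 1e-3, relative bound 1e-2:
  std::printf("%d\n", ValuesNear(1000.0f, 1005.0f, 1e-3f, 1e-2f));  // 1: rel err 0.5%
  std::printf("%d\n", ValuesNear(1e-6f, 2e-6f, 1e-3f, 1e-2f));      // 1: abs diff 1e-6
  std::printf("%d\n", ValuesNear(1.0f, 1.5f, 1e-3f, 1e-2f));        // 0: fails both bounds
  return 0;
}

The recursive ExpectLiteralsNear that follows applies this check at every index and records the miscompared positions in miscompares_.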
+ template + void ExpectLiteralsNear(const Literal& expected, const Literal& actual, + int64 dimension) { + if (dimension == expected.shape().dimensions_size()) { + bool near = + ExpectValuesNear(LiteralUtil::Get(expected, multi_index_), + LiteralUtil::Get(actual, multi_index_)); + LiteralUtil::Set(&miscompares_, multi_index_, !near); + } else { + for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) { + multi_index_[dimension] = i; + ExpectLiteralsNear(expected, actual, dimension + 1); + } + } + } + + // Writes the given literal to a file in the test temporary directory. + void WriteLiteralToTempFile(const Literal& literal, const string& name) { + int64 now_usec = tensorflow::Env::Default()->NowMicros(); + string filename = tensorflow::io::JoinPath( + tensorflow::testing::TmpDir(), + tensorflow::strings::Printf("tempfile-%s-%llx-%s", Hostname().c_str(), + now_usec, name.c_str())); + TF_CHECK_OK(tensorflow::WriteBinaryProto(tensorflow::Env::Default(), + filename, literal)); + LOG(ERROR) << "wrote to " << name << " file: " << filename; + } + + ErrorSpec error_; + + // Number of element miscomparisons encountered so far. + int64 num_miscompares_; + + // A Literal containing which elements did not match in the expected and + // actual literals. miscompares_ contains PREDs and is of the same sizes as + // the comparison literals. + Literal miscompares_; + + // A multidimensional index used when performing the recursive comparison. + std::vector multi_index_; + + // Aggregated Statistics on input. + double abs_diff_sum_; + double abs_expected_sum_; + double abs_diff_miscompare_sum_; + double abs_expected_miscompare_sum_; + float max_rel_err_; + float max_abs_err_; + std::vector first_multi_index_; + std::vector last_multi_index_; + std::vector max_rel_multi_index_; + std::vector max_abs_multi_index_; +}; + +} // namespace + +/* static */ testing::AssertionResult LiteralTestUtil::Near( + const Literal& expected, const Literal& actual, const ErrorSpec& error) { + NearComparator comparator(error); + return comparator.ExpectNear(expected, actual) + ? 
testing::AssertionSuccess() + : testing::AssertionFailure() << "values were not near"; +} + +/* static */ void LiteralTestUtil::ExpectNear(const Literal& expected, + const Literal& actual, + const ErrorSpec& error) { + EXPECT_TRUE(Near(expected, actual, error)); +} + +/* static */ testing::AssertionResult LiteralTestUtil::NearTuple( + const Literal& expected, const Literal& actual, const ErrorSpec& error) { + VLOG(1) << "expected: " << LiteralUtil::ToString(expected); + VLOG(1) << "actual: " << LiteralUtil::ToString(actual); + + if (!ShapeUtil::IsTuple(expected.shape()) || + !ShapeUtil::IsTuple(actual.shape())) { + return testing::AssertionFailure() + << "tuples expected expected shape = " + << expected.shape().ShortDebugString() + << " actual shape = " << actual.shape().ShortDebugString(); + } + AssertEqualShapes(expected.shape(), actual.shape()); + for (uint64 i = 0; i < expected.tuple_literals_size(); ++i) { + const auto& expected_element = expected.tuple_literals(i); + const auto& actual_element = actual.tuple_literals(i); + if (ShapeUtil::IsTuple(expected_element.shape())) { + auto ret = NearTuple(expected_element, actual_element, error); + if (!ret) { + return ret; + } + } else if (ShapeUtil::ElementIsFloating(expected_element.shape())) { + auto ret = Near(expected_element, actual_element, error); + if (!ret) { + return ret; + } + } else { + auto ret = Equal(expected_element, actual_element); + if (!ret) { + return ret; + } + } + } + + return testing::AssertionSuccess(); +} + +/* static */ void LiteralTestUtil::ExpectNearTuple(const Literal& expected, + const Literal& actual, + const ErrorSpec& error) { + EXPECT_TRUE(NearTuple(expected, actual, error)); +} + +/* static */ string LiteralTestUtil::MultiIndexAsString( + tensorflow::gtl::ArraySlice multi_index) { + return tensorflow::strings::StrCat( + "{", tensorflow::str_util::Join(multi_index, ","), "}"); +} + +/* static */ std::unique_ptr LiteralTestUtil::Reshape( + tensorflow::gtl::ArraySlice new_dimensions, + tensorflow::gtl::ArraySlice minor_to_major, const Literal& literal) { + int64 new_num_elements = 1; + for (int64 i = 0; i < new_dimensions.size(); ++i) { + new_num_elements *= new_dimensions[i]; + } + CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements); + + auto new_literal = MakeUnique(); + *new_literal->mutable_shape() = + ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions); + + // Create a new shape with the given minor-to-major layout. This shape is used + // solely for converting linear address to multi-dimensional addresses when + // writing elements to the new literal. + Shape shape_with_layout = new_literal->shape(); + *shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major); + + // Allocate space in the new literal. + LiteralUtil::Reserve(ShapeUtil::ElementsIn(literal.shape()), + new_literal.get()); + + // Copy data into new literal, element-by-element. 
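The element-by-element copy that follows relies on IndexUtil::LinearIndexToMultidimensionalIndex to undo the linearization implied by a layout. That inverse mapping is just repeated division by the dimension sizes taken in minor-to-major order; a standalone sketch with a hypothetical Delinearize helper (not the XLA API):

#include <cstdint>
#include <cstdio>
#include <vector>

// Decomposes `linear` into a multi-dimensional index for an array with the
// given `dims`, where dims[minor_to_major[0]] varies fastest in memory.
std::vector<int64_t> Delinearize(int64_t linear,
                                 const std::vector<int64_t>& dims,
                                 const std::vector<int64_t>& minor_to_major) {
  std::vector<int64_t> index(dims.size(), 0);
  for (int64_t dim : minor_to_major) {  // Minor to major.
    index[dim] = linear % dims[dim];
    linear /= dims[dim];
  }
  return index;
}

int main() {
  // A {2, 3} array with column-major layout {0, 1}: linear offset 4 lands at
  // row 0, column 2.
  std::vector<int64_t> idx = Delinearize(4, {2, 3}, {0, 1});
  std::printf("{%lld, %lld}\n", static_cast<long long>(idx[0]),
              static_cast<long long>(idx[1]));  // Prints "{0, 2}".
  return 0;
}

The switch in the loop below then dispatches on the element type to perform each copy.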
+ for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) { + std::vector from_multi_index = + IndexUtil::LinearIndexToMultidimensionalIndex(literal.shape(), i); + std::vector to_multi_index = + IndexUtil::LinearIndexToMultidimensionalIndex(shape_with_layout, i); + switch (literal.shape().element_type()) { + case PRED: + LiteralUtil::Set( + new_literal.get(), to_multi_index, + LiteralUtil::Get(literal, from_multi_index)); + break; + case U8: + LiteralUtil::Set( + new_literal.get(), to_multi_index, + LiteralUtil::Get(literal, from_multi_index)); + break; + case U32: + LiteralUtil::Set( + new_literal.get(), to_multi_index, + LiteralUtil::Get(literal, from_multi_index)); + break; + case S32: + LiteralUtil::Set( + new_literal.get(), to_multi_index, + LiteralUtil::Get(literal, from_multi_index)); + break; + case U64: + LiteralUtil::Set( + new_literal.get(), to_multi_index, + LiteralUtil::Get(literal, from_multi_index)); + break; + case S64: + LiteralUtil::Set( + new_literal.get(), to_multi_index, + LiteralUtil::Get(literal, from_multi_index)); + break; + case F32: + LiteralUtil::Set( + new_literal.get(), to_multi_index, + LiteralUtil::Get(literal, from_multi_index)); + break; + case F64: + LiteralUtil::Set( + new_literal.get(), to_multi_index, + LiteralUtil::Get(literal, from_multi_index)); + break; + default: + LOG(FATAL) << "Unhandled primitive element type: " + << PrimitiveType_Name(literal.shape().element_type()); + } + } + + return new_literal; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h new file mode 100644 index 0000000000..85656a53e4 --- /dev/null +++ b/tensorflow/compiler/xla/tests/literal_test_util.h @@ -0,0 +1,274 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_TESTS_LITERAL_TEST_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_TESTS_LITERAL_TEST_UTIL_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/array3d.h" +#include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// Structure describing permissible absolute and relative error bounds. +struct ErrorSpec { + explicit ErrorSpec(float aabs, float arel = 0) : abs(aabs), rel(arel) {} + + float abs; // Absolute error bound. + float rel; // Relative error bound. +}; + +// Utility class for making expectations/assertions related to XLA literals. 
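A hedged usage sketch of ErrorSpec together with the Near helpers declared in the class below (the literal values and the wrapper function are invented for illustration; ExpectR1Near is the method declared in this header): with ErrorSpec(1e-4, 1e-3), an element passes if it is within 1e-4 of the expected value absolutely or within 0.1% of it relatively.

#include "tensorflow/compiler/xla/tests/literal_test_util.h"

namespace xla {

// Invented helper: checks a rank-1 result against known values, tolerating
// either 1e-4 absolute error or 0.1% relative error per element.
void CheckApproximateResult(const Literal& actual) {
  ErrorSpec spec(1e-4, 1e-3);
  LiteralTestUtil::ExpectR1Near<float>({1.0f, 2.0f, 3.0f}, actual, spec);
}

}  // namespace xla

The LiteralTestUtil class itself follows.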
+class LiteralTestUtil { + public: + // Asserts that the given shapes have the same rank, dimension sizes, and + // primitive types. + static void AssertEqualShapes(const Shape& expected, const Shape& actual); + + // Asserts that the provided shapes are equal as defined in AssertEqualShapes + // and that they have the same layout. + static void AssertEqualShapesAndLayouts(const Shape& expected, + const Shape& actual); + + // Asserts that the expected and actual literals are (bitwise) equal for all + // elements in the literal. Also, asserts that the rank, dimensions sizes, and + // primitive type are equal. + static testing::AssertionResult Equal( + const Literal& expected, const Literal& actual) TF_MUST_USE_RESULT; + + // Expects that expected and actual are Equal. + static void ExpectEqual(const Literal& expected, const Literal& actual); + + // Expects that expected and actual are Not Equal. + static void ExpectNotEqual(const Literal& expected, const Literal& actual); + + // Asserts the given literal are (bitwise) equal to given expected values. + template + static void ExpectR0Equal(NativeT expected, const Literal& actual); + template + static void ExpectR1Equal(tensorflow::gtl::ArraySlice expected, + const Literal& actual); + template + static void ExpectR2Equal( + std::initializer_list> expected, + const Literal& actual); + template + static void ExpectR3Equal( + std::initializer_list< + std::initializer_list>> + expected, + const Literal& actual); + + // Asserts the given literal are (bitwise) equal to given array. + template + static void ExpectR2EqualArray2D(const Array2D& expected, + const Literal& actual); + template + static void ExpectR3EqualArray3D(const Array3D& expected, + const Literal& actual); + template + static void ExpectR4EqualArray4D(const Array4D& expected, + const Literal& actual); + + // Expects that the values of the elements in the expected and actual tuples + // are equal. Tuples are matched recursively. + static void ExpectEqualTuple(const Literal& expected, const Literal& actual); + + // Asserts that the expected and actual literals are within the given error + // bound for all elements. Also, asserts that the rank, dimensions sizes, and + // bounds are equivalent. Only supported for floating point values. + static testing::AssertionResult Near( + const Literal& expected, const Literal& actual, + const ErrorSpec& error) TF_MUST_USE_RESULT; + + // Expects expected and actual to be Near with the given error. + static void ExpectNear(const Literal& expected, const Literal& actual, + const ErrorSpec& error); + + // Asserts the given literal are within the given error bound of the given + // expected values. Only supported for floating point values. + template + static void ExpectR0Near(NativeT expected, const Literal& actual, + const ErrorSpec& error); + template + static void ExpectR1Near(tensorflow::gtl::ArraySlice expected, + const Literal& actual, const ErrorSpec& error); + template + static void ExpectR2Near( + std::initializer_list> expected, + const Literal& actual, const ErrorSpec& error); + template + static void ExpectR3Near( + std::initializer_list< + std::initializer_list>> + expected, + const Literal& actual, const ErrorSpec& error); + + // Asserts the given literal are within the given error bound to the given + // array. Only supported for floating point values. 
+ template + static void ExpectR2NearArray2D(const Array2D& expected, + const Literal& actual, + const ErrorSpec& error); + template + static void ExpectR3NearArray3D(const Array3D& expected, + const Literal& actual, + const ErrorSpec& error); + template + static void ExpectR4NearArray4D(const Array4D& expected, + const Literal& actual, + const ErrorSpec& error); + + // Returns whether the values of the elements in the expected and actual + // tuples are within the given error bound. Tuples are matched recursively. + // If the elements of the tuple are not floating-point types, the error spec + // is ignored and exact equality is checked. + static testing::AssertionResult NearTuple( + const Literal& expected, const Literal& actual, + const ErrorSpec& error) TF_MUST_USE_RESULT; + + // Expects that the expected and actual values are near. + static void ExpectNearTuple(const Literal& expected, const Literal& actual, + const ErrorSpec& error); + + // Returns a multi-dimensional index as a string. For example: '{7, 8}' will + // be returned for a 2-dimensional index with dimension 0 index equal to 7, + // dimension 1 equal to 8. + static string MultiIndexAsString( + tensorflow::gtl::ArraySlice multi_index); + + // Creates a literal with a new shape with the given new dimensions using the + // data in the given input literal. For reshaping purposes the (flat) data + // buffer of the input literal is assumed to have the given minor_to_major + // layout order. + static std::unique_ptr Reshape( + tensorflow::gtl::ArraySlice new_dimensions, + tensorflow::gtl::ArraySlice minor_to_major, + const Literal& literal); + + private: + TF_DISALLOW_COPY_AND_ASSIGN(LiteralTestUtil); +}; + +template +/* static */ void LiteralTestUtil::ExpectR0Equal(NativeT expected, + const Literal& actual) { + ExpectEqual(*LiteralUtil::CreateR0(expected), actual); +} + +template +/* static */ void LiteralTestUtil::ExpectR1Equal( + tensorflow::gtl::ArraySlice expected, const Literal& actual) { + ExpectEqual(*LiteralUtil::CreateR1(expected), actual); +} + +template +/* static */ void LiteralTestUtil::ExpectR2Equal( + std::initializer_list> expected, + const Literal& actual) { + ExpectEqual(*LiteralUtil::CreateR2(expected), actual); +} + +template +/* static */ void LiteralTestUtil::ExpectR3Equal( + std::initializer_list>> + expected, + const Literal& actual) { + ExpectEqual(*LiteralUtil::CreateR3(expected), actual); +} + +template +/* static */ void LiteralTestUtil::ExpectR2EqualArray2D( + const Array2D& expected, const Literal& actual) { + ExpectEqual(*LiteralUtil::CreateR2FromArray2D(expected), actual); +} + +template +/* static */ void LiteralTestUtil::ExpectR3EqualArray3D( + const Array3D& expected, const Literal& actual) { + ExpectEqual(*LiteralUtil::CreateR3FromArray3D(expected), actual); +} + +template +/* static */ void LiteralTestUtil::ExpectR4EqualArray4D( + const Array4D& expected, const Literal& actual) { + ExpectEqual(*LiteralUtil::CreateR4FromArray4D(expected), actual); +} + +template +/* static */ void LiteralTestUtil::ExpectR0Near(NativeT expected, + const Literal& actual, + const ErrorSpec& error) { + ExpectNear(*LiteralUtil::CreateR0(expected), actual, error); +} + +template +/* static */ void LiteralTestUtil::ExpectR1Near( + tensorflow::gtl::ArraySlice expected, const Literal& actual, + const ErrorSpec& error) { + ExpectNear(*LiteralUtil::CreateR1(expected), actual, error); +} + +template +/* static */ void LiteralTestUtil::ExpectR2Near( + std::initializer_list> expected, + const Literal& actual, const 
ErrorSpec& error) { + ExpectNear(*LiteralUtil::CreateR2(expected), actual, error); +} + +template +/* static */ void LiteralTestUtil::ExpectR3Near( + std::initializer_list>> + expected, + const Literal& actual, const ErrorSpec& error) { + ExpectNear(*LiteralUtil::CreateR3(expected), actual, error); +} + +template +/* static */ void LiteralTestUtil::ExpectR2NearArray2D( + const Array2D& expected, const Literal& actual, + const ErrorSpec& error) { + ExpectNear(*LiteralUtil::CreateR2FromArray2D(expected), actual, error); +} + +template +/* static */ void LiteralTestUtil::ExpectR3NearArray3D( + const Array3D& expected, const Literal& actual, + const ErrorSpec& error) { + ExpectNear(*LiteralUtil::CreateR3FromArray3D(expected), actual, error); +} + +template +/* static */ void LiteralTestUtil::ExpectR4NearArray4D( + const Array4D& expected, const Literal& actual, + const ErrorSpec& error) { + ExpectNear(*LiteralUtil::CreateR4FromArray4D(expected), actual, error); +} + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_TESTS_LITERAL_TEST_UTIL_H_ diff --git a/tensorflow/compiler/xla/tests/literal_test_util_test.cc b/tensorflow/compiler/xla/tests/literal_test_util_test.cc new file mode 100644 index 0000000000..fdec11c0e9 --- /dev/null +++ b/tensorflow/compiler/xla/tests/literal_test_util_test.cc @@ -0,0 +1,102 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Tests that our utility functions for dealing with literals are correctly +// implemented. + +#include "tensorflow/compiler/xla/tests/literal_test_util.h" + +#include + +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +TEST(LiteralTestUtilTest, ComparesEqualTuplesEqual) { + std::unique_ptr literal = LiteralUtil::MakeTuple({ + LiteralUtil::CreateR0(42).get(), + LiteralUtil::CreateR0(64).get(), + }); + LiteralTestUtil::ExpectEqual(*literal, *literal); +} + +TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) { + // Implementation note: we have to use a death test here, because you can't + // un-fail an assertion failure. The CHECK-failure is death, so we can make a + // death assertion. 
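What makes this workable is that LiteralTestUtil::Equal returns a testing::AssertionResult, which converts to bool, so it can feed a CHECK inside a lambda; the CHECK failure aborts the process, and ASSERT_DEATH observes the abort and its message. A minimal sketch of the same shape, with an invented always-failing helper:

#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"

namespace {

// An assertion helper in the same style as LiteralTestUtil::Equal: usable as
// a bool, but carrying a message when it fails.
::testing::AssertionResult Failing() {
  return ::testing::AssertionFailure() << "always fails";
}

TEST(DeathTestSketch, CheckFailureIsObservableAsDeath) {
  auto fails_a_check = [] { CHECK(Failing()) << "boom"; };
  ASSERT_DEATH(fails_a_check(), "boom");
}

}  // namespace

The actual test below applies exactly this pattern to two tuples built in swapped order.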
+ auto unequal_things_are_equal = [] { + std::unique_ptr lhs = LiteralUtil::MakeTuple({ + LiteralUtil::CreateR0(42).get(), + LiteralUtil::CreateR0(64).get(), + }); + std::unique_ptr rhs = LiteralUtil::MakeTuple({ + LiteralUtil::CreateR0(64).get(), + LiteralUtil::CreateR0(42).get(), + }); + CHECK(LiteralTestUtil::Equal(*lhs, *rhs)) << "LHS and RHS are unequal"; + }; + ASSERT_DEATH(unequal_things_are_equal(), "LHS and RHS are unequal"); +} + +TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) { + auto dummy_lambda = [] { + auto two = LiteralUtil::CreateR0(2); + auto four = LiteralUtil::CreateR0(4); + ErrorSpec error(0.001); + CHECK(LiteralTestUtil::Near(*two, *four, error)) << "two is not near four"; + }; + + tensorflow::Env* env = tensorflow::Env::Default(); + string pattern = + tensorflow::io::JoinPath(tensorflow::testing::TmpDir(), "/tempfile-*"); + std::vector files; + TF_CHECK_OK(env->GetMatchingPaths(pattern, &files)); + for (const auto& f : files) { + TF_CHECK_OK(env->DeleteFile(f)) << f; + } + + ASSERT_DEATH(dummy_lambda(), "two is not near four"); + + // Now check we wrote temporary files to the temporary directory that we can + // read. + std::vector results; + TF_CHECK_OK(env->GetMatchingPaths(pattern, &results)); + + LOG(INFO) << "results: [" << tensorflow::str_util::Join(results, ", ") << "]"; + EXPECT_EQ(3, results.size()); + for (const string& result : results) { + Literal literal; + TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), result, + &literal)); + if (result.find("expected") != string::npos) { + EXPECT_EQ("2", LiteralUtil::ToString(literal)); + } else if (result.find("actual") != string::npos) { + EXPECT_EQ("4", LiteralUtil::ToString(literal)); + } else if (result.find("miscompares") != string::npos) { + EXPECT_EQ("true", LiteralUtil::ToString(literal)); + } else { + FAIL() << "unknown file in temporary directory: " << result; + } + } +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test.cc b/tensorflow/compiler/xla/tests/local_client_aot_test.cc new file mode 100644 index 0000000000..591fff338c --- /dev/null +++ b/tensorflow/compiler/xla/tests/local_client_aot_test.cc @@ -0,0 +1,55 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/executable_run_options.h" +#include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/test.h" + +class LocalClientAotTest : public ::testing::Test {}; + +// This is a compiled XLA computation which calls SumStructElements, and then +// doubles the result. +extern "C" void SumAndDouble(float* out, xla::ExecutableRunOptions* options, + void** parameters, void** temporary_buffers); + +// Just some structs with some arbitrary fields used to test the OPAQUE type. 
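The expected outputs in the Constant test further down are plain arithmetic: SumStructElements adds the three bit-fields and the compiled SumAndDouble computation doubles that sum, so {100, 20, 3} yields 2 * 123 = 246 and {1, 2, 3} yields 2 * 6 = 12. A compile-time restatement of that arithmetic, with an invented helper name:

// Compile-time check mirroring the expectations of LocalClientAotTest.Constant.
constexpr float SumAndDoubleFields(int f1, int f2, int f3) {
  return 2.0f * static_cast<float>(f1 + f2 + f3);
}

static_assert(SumAndDoubleFields(100, 20, 3) == 246.0f, "first invocation");
static_assert(SumAndDoubleFields(1, 2, 3) == 12.0f, "second invocation");

The OpaqueData struct carrying those bit-fields is defined next.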
+struct OpaqueData { + int field1 : 15; + int field2 : 14; + int field3 : 3; +}; + +// This is the implementation of a custom op which will be called by +// SumAndDouble. +extern "C" void SumStructElements(float* out, void** parameters) { + TF_ANNOTATE_MEMORY_IS_INITIALIZED(parameters, sizeof(OpaqueData*)); + const auto* opaque_data = static_cast(parameters[0]); + *out = opaque_data->field1 + opaque_data->field2 + opaque_data->field3; +} + +TEST_F(LocalClientAotTest, Constant) { + xla::ExecutableRunOptions run_options; + OpaqueData opaque_data{100, 20, 3}; + void* parameters[] = {&opaque_data}; + float out = 0; + float tmp = 0; + void* temporary_buffers[] = {&out, &tmp, nullptr}; + SumAndDouble(&out, &run_options, parameters, temporary_buffers); + EXPECT_EQ(out, 246.0f); + + opaque_data = {1, 2, 3}; + SumAndDouble(&out, &run_options, parameters, temporary_buffers); + EXPECT_EQ(out, 12.0f); +} diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc new file mode 100644 index 0000000000..50e5dec0f6 --- /dev/null +++ b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc @@ -0,0 +1,111 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This program compiles an XLA program which computes 123 and writes the +// resulting object file to stdout. 
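The helper ends by sanity-checking the first bytes of the emitted object file before streaming it to stdout. For reference, the ELF identification bytes it inspects are the magic 0x7F 'E' 'L' 'F', then the file class (1 = 32-bit, 2 = 64-bit), the data encoding (1 = little-endian), and the version (always 1). A small hypothetical checker over the same bytes, shown only to make those offsets concrete:

#include <vector>

// Hypothetical helper: true if `data` begins with a 64-bit little-endian ELF
// identification (magic, ELFCLASS64, ELFDATA2LSB, EV_CURRENT).
bool LooksLikeElf64LE(const std::vector<char>& data) {
  return data.size() > 6 && data[0] == 0x7F && data[1] == 'E' &&
         data[2] == 'L' && data[3] == 'F' && data[4] == 2 && data[5] == 1 &&
         data[6] == 1;
}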
+ +#include +#include + +#include "external/llvm/include/llvm/ADT/Triple.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/logging.h" + +using xla::string; + +xla::Computation Doubler(xla::Client* client) { + xla::ComputationBuilder builder(client, "doubler"); + auto r0f32 = xla::ShapeUtil::MakeShape(xla::F32, {}); + auto x = builder.Parameter(0, r0f32, "x"); + builder.Mul(x, builder.ConstantR0(2.0)); + return std::move(builder.Build().ValueOrDie()); +} + +int main(int argc, char** argv) { + tensorflow::port::InitMain(argv[0], &argc, &argv); + + auto client = xla::ClientLibrary::LocalClientOrDie(); + + xla::ComputationBuilder builder(client, "aot_test_helper"); + auto opaque_shape = xla::ShapeUtil::MakeOpaqueShape(); + auto opaque_param = builder.Parameter(0, opaque_shape, "x"); + auto r0f32 = xla::ShapeUtil::MakeShape(xla::F32, {}); + auto sum = builder.CustomCall("SumStructElements", {opaque_param}, r0f32); + builder.Call(Doubler(client), {sum}); + + if (argc != 2) { + LOG(FATAL) << "local_client_aot_test_helper TARGET_CPU"; + } + + string triple_string; + string target_cpu = argv[1]; + if (target_cpu == "k8") { + triple_string = "x86_64-none-linux-gnu"; + } else if (target_cpu == "darwin") { + triple_string = "x86_64-apple-macosx"; + } else if (target_cpu == "arm") { + triple_string = "aarch64-none-linux-gnu"; + } else if (target_cpu == "ppc") { + triple_string = "powerpc64le-unknown-linux-gnu"; + } else if (target_cpu == "local") { + triple_string = xla::llvm_ir::AsString(llvm::sys::getDefaultTargetTriple()); + } else { + LOG(FATAL) << "unsupported TARGET_CPU: " << target_cpu; + } + + llvm::Triple triple(xla::llvm_ir::AsStringRef(triple_string)); + + xla::cpu::CpuAotCompilationOptions options( + triple_string, + /*cpu_name=*/"", /*features=*/"", "SumAndDouble", + xla::cpu::CpuAotCompilationOptions::RelocationModel::Static); + auto result = xla::unique_ptr_static_cast( + client + ->CompileAheadOfTime(builder.Build().ValueOrDie(), + /*argument_layouts=*/{&opaque_shape}, r0f32, + options) + .ConsumeValueOrDie()); + // We should have two buffers, one for the result and one temporary buffer, + // and both should be float-sized. It's lame to hard-code this, but we need + // local_client_aot_test.cc to be able to easily invoke the function. + CHECK_EQ(result->result_buffer_index(), 0); + CHECK_EQ(result->buffer_sizes().size(), 3); + CHECK_EQ(result->buffer_sizes()[0], sizeof(float)); // result buffer + CHECK_EQ(result->buffer_sizes()[1], sizeof(float)); // temp buffer + CHECK_EQ(result->buffer_sizes()[2], -1); + if (triple.isOSBinFormatELF()) { + // Check the ELF magic. + CHECK_EQ(result->object_file_data()[0], 0x7F); + CHECK_EQ(result->object_file_data()[1], 'E'); + CHECK_EQ(result->object_file_data()[2], 'L'); + CHECK_EQ(result->object_file_data()[3], 'F'); + // Check the ELF class. + CHECK_EQ(result->object_file_data()[4], triple.isArch32Bit() ? 1 : 2); + // Check the ELF endianness: it should be little. + CHECK_EQ(result->object_file_data()[5], triple.isLittleEndian() ? 1 : 2); + // Check the ELF version: it should be 1. 
+ CHECK_EQ(result->object_file_data()[6], 1); + } + + const std::vector& object_file_data = result->object_file_data(); + std::cout.write(object_file_data.data(), object_file_data.size()); + + return 0; +} diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc new file mode 100644 index 0000000000..5c32ed8895 --- /dev/null +++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc @@ -0,0 +1,220 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/tests/local_client_test_base.h" + +#include + +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +/* static */ TestAllocator* LocalClientTestBase::allocator_; + +StatusOr TestAllocator::Allocate( + int device_ordinal, uint64 size, bool retry_on_failure) { + VLOG(2) << "Allocate(" << device_ordinal << ", " << size << ")"; + { + tensorflow::mutex_lock lock(count_mutex_); + allocation_count_++; + device_allocation_count_[device_ordinal]++; + } + return StreamExecutorMemoryAllocator::Allocate(device_ordinal, size); +} + +tensorflow::Status TestAllocator::Deallocate( + int device_ordinal, perftools::gputools::DeviceMemoryBase* mem) { + VLOG(2) << "Deallocate(" << device_ordinal << ")"; + { + tensorflow::mutex_lock lock(count_mutex_); + deallocation_count_++; + device_deallocation_count_[device_ordinal]++; + } + return StreamExecutorMemoryAllocator::Deallocate(device_ordinal, mem); +} + +int64 TestAllocator::allocation_count() const { + tensorflow::mutex_lock lock(count_mutex_); + return allocation_count_; +} + +int64 TestAllocator::allocation_count(int device_ordinal) const { + tensorflow::mutex_lock lock(count_mutex_); + auto it = device_allocation_count_.find(device_ordinal); + if (it == device_allocation_count_.end()) { + return 0; + } else { + return it->second; + } +} + +int64 TestAllocator::deallocation_count() const { + tensorflow::mutex_lock lock(count_mutex_); + return deallocation_count_; +} + +int64 TestAllocator::deallocation_count(int device_ordinal) const { + tensorflow::mutex_lock lock(count_mutex_); + auto it = device_deallocation_count_.find(device_ordinal); + if (it == device_deallocation_count_.end()) { + return 0; + } else { + return it->second; + } +} + +/* static */ TestAllocator* LocalClientTestBase::GetOrCreateAllocator( + perftools::gputools::Platform* platform) { + if (allocator_ == nullptr) { + allocator_ = new TestAllocator( + platform == nullptr ? 
PlatformUtil::GetDefaultPlatform().ValueOrDie() + : platform); + } + return allocator_; +} + +LocalClientTestBase::LocalClientTestBase( + perftools::gputools::Platform* platform) + : local_client_( + ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie()) { + stream_executor_ = PlatformUtil::GetStreamExecutors(local_client_->platform()) + .ValueOrDie()[local_client_->default_device_ordinal()]; + transfer_manager_ = + TransferManager::GetForPlatform(local_client_->platform()).ValueOrDie(); +} + +std::unique_ptr +LocalClientTestBase::LiteralToScopedShapedBuffer(const Literal& literal) { + return LiteralToScopedShapedBuffer(literal, + local_client_->default_device_ordinal()); +} + +std::unique_ptr +LocalClientTestBase::LiteralToScopedShapedBuffer(const Literal& literal, + int device_ordinal) { + CHECK(!ShapeUtil::IsTuple(literal.shape())); + auto scoped_buffer = + ScopedShapedBuffer::MakeScopedShapedBuffer( + literal.shape(), GetOrCreateAllocator(local_client_->platform()), + device_ordinal) + .ConsumeValueOrDie(); + // The creation of the scoped shaped buffer should allocate the buffer. + CHECK(!scoped_buffer->buffer(/*index=*/{}).is_null() || + ShapeUtil::HasZeroElements(literal.shape())); + TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice( + stream_executor_, literal, scoped_buffer->mutable_buffer(/*index=*/{}))); + return scoped_buffer; +} + +void LocalClientTestBase::CopyShapedBufferToLiteral( + const ShapedBuffer& shaped_buffer, ShapeIndex* index, Literal* literal) { + const Shape& shape = ShapeUtil::GetSubshape(shaped_buffer.shape(), *index); + if (ShapeUtil::IsTuple(shape)) { + *literal->mutable_shape() = shape; + for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + Literal* element_literal = literal->add_tuple_literals(); + index->push_back(i); + CopyShapedBufferToLiteral(shaped_buffer, index, element_literal); + index->pop_back(); + } + } else { + ASSERT_IS_OK(transfer_manager_->TransferLiteralFromDevice( + stream_executor_, shaped_buffer.buffer(*index), shape, shape, literal)); + } +} + +std::unique_ptr LocalClientTestBase::ShapedBufferToLiteral( + const ShapedBuffer& shaped_buffer) { + auto literal = MakeUnique(); + ShapeIndex index; + CopyShapedBufferToLiteral(shaped_buffer, &index, literal.get()); + return literal; +} + +std::unique_ptr +LocalClientTestBase::ShapedBufferToScopedShapedBuffer( + std::unique_ptr shaped_buffer, + DeviceMemoryAllocator* allocator) { + std::unique_ptr scoped_buffer = + ScopedShapedBuffer::MakeScopedShapedBuffer( + shaped_buffer->shape(), allocator, shaped_buffer->device_ordinal()) + .ConsumeValueOrDie(); + // Deallocate the existing DeviceMemoryBase values in the newly created scoped + // buffer and replace them with the values from the shaped buffer. 
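The loop that follows is effectively an ownership transfer: the scoped buffer's freshly allocated device memory is released, the raw ShapedBuffer's pointers are adopted in its place, and from then on the ScopedShapedBuffer destructor is responsible for freeing them. A rough host-side RAII analogue with toy types (not the XLA classes):

#include <cstdlib>
#include <utility>
#include <vector>

// Toy RAII owner of raw allocations, standing in for ScopedShapedBuffer's
// ownership of device memory.
class ScopedBuffers {
 public:
  explicit ScopedBuffers(std::vector<void*> buffers)
      : buffers_(std::move(buffers)) {}
  ScopedBuffers(const ScopedBuffers&) = delete;
  ScopedBuffers& operator=(const ScopedBuffers&) = delete;
  ~ScopedBuffers() {
    for (void* p : buffers_) std::free(p);
  }

  // Frees the currently owned allocations and takes ownership of
  // `replacements` instead, mirroring the buffer swap performed above.
  void AdoptBuffers(std::vector<void*> replacements) {
    for (void* p : buffers_) std::free(p);
    buffers_ = std::move(replacements);
  }

 private:
  std::vector<void*> buffers_;
};

int main() {
  ScopedBuffers scoped({std::malloc(16)});
  scoped.AdoptBuffers({std::malloc(32)});  // The original allocation is freed.
  return 0;  // The adopted allocation is freed by the destructor.
}

The loop below performs the deallocation half of that swap before copying the buffer table across.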
+ for (perftools::gputools::DeviceMemoryBase& memory_base : + *scoped_buffer->mutable_buffers()) { + TF_CHECK_OK( + allocator->Deallocate(shaped_buffer->device_ordinal(), &memory_base)); + } + *scoped_buffer->mutable_buffers() = shaped_buffer->buffers(); + + TF_CHECK_OK( + scoped_buffer->mutable_shape_index_to_buffer_entry() + ->ForEachMutableElement( + [&shaped_buffer](const ShapeIndex& index, bool is_leaf, + size_t* buffer_entry) -> ::tensorflow::Status { + if (is_leaf) { + *buffer_entry = + shaped_buffer->shape_index_to_buffer_entry().element( + index); + } + return tensorflow::Status::OK(); + })); + return scoped_buffer; +} + +LocalExecuteOptions LocalClientTestBase::DefaultLocalExecuteOptions() const { + return LocalExecuteOptions().set_allocator( + GetOrCreateAllocator(local_client_->platform())); +} + +std::unique_ptr LocalClientTestBase::ExecuteLocally( + const Computation& computation, + tensorflow::gtl::ArraySlice arguments) { + return ExecuteLocally(computation, arguments, DefaultLocalExecuteOptions()); +} + +std::unique_ptr LocalClientTestBase::ExecuteLocally( + const Computation& computation, + tensorflow::gtl::ArraySlice arguments, + const LocalExecuteOptions& options) { + return ShapedBufferToScopedShapedBuffer( + local_client_->ExecuteLocally(computation, arguments, options) + .ConsumeValueOrDie(), + options.allocator()); +} + +void LocalClientTestBase::ExecuteLocally( + const Computation& computation, + tensorflow::gtl::ArraySlice arguments, + ShapedBuffer* result) { + ExecuteLocally(computation, arguments, DefaultLocalExecuteOptions(), result); +} + +void LocalClientTestBase::ExecuteLocally( + const Computation& computation, + tensorflow::gtl::ArraySlice arguments, + const LocalExecuteOptions& options, ShapedBuffer* result) { + ASSERT_IS_OK( + local_client_->ExecuteLocally(computation, arguments, options, result)); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.h b/tensorflow/compiler/xla/tests/local_client_test_base.h new file mode 100644 index 0000000000..62916d50e3 --- /dev/null +++ b/tensorflow/compiler/xla/tests/local_client_test_base.h @@ -0,0 +1,146 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_TESTS_LOCAL_CLIENT_TEST_BASE_H_ +#define TENSORFLOW_COMPILER_XLA_TESTS_LOCAL_CLIENT_TEST_BASE_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/local_service.h" +#include "tensorflow/compiler/xla/service/platform_util.h" +#include "tensorflow/compiler/xla/service/shaped_buffer.h" +#include "tensorflow/compiler/xla/service/transfer_manager.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +class TestAllocator : public StreamExecutorMemoryAllocator { + public: + explicit TestAllocator(perftools::gputools::Platform* platform) + : StreamExecutorMemoryAllocator( + platform, PlatformUtil::GetStreamExecutors(platform).ValueOrDie()) { + } + + StatusOr Allocate( + int device_ordinal, uint64 size, bool retry_on_failure) override; + tensorflow::Status Deallocate( + int device_ordinal, perftools::gputools::DeviceMemoryBase* mem) override; + + // Return the number of allocations that have been performed. + int64 allocation_count() const; + int64 allocation_count(int device_ordinal) const; + + // Return the number of deallocations that have been performed. + int64 deallocation_count() const; + int64 deallocation_count(int device_ordinal) const; + + private: + mutable tensorflow::mutex count_mutex_; + + // Global counts of allocations and deallocations. + int64 allocation_count_ GUARDED_BY(count_mutex_) = 0; + int64 deallocation_count_ GUARDED_BY(count_mutex_) = 0; + + // Per-device counts of allocations and deallocations. + std::map device_allocation_count_ GUARDED_BY(count_mutex_); + std::map device_deallocation_count_ GUARDED_BY(count_mutex_); +}; + +// A base class for tests which exercise the LocalClient interface. +class LocalClientTestBase : public ::testing::Test { + protected: + explicit LocalClientTestBase( + perftools::gputools::Platform* platform = nullptr); + + static TestAllocator* GetOrCreateAllocator( + perftools::gputools::Platform* platform); + + // Copy the given literal onto the default device and return a + // ScopedShapedBuffer. + std::unique_ptr LiteralToScopedShapedBuffer( + const Literal& literal); + // As above, but copy to a specific device. + std::unique_ptr LiteralToScopedShapedBuffer( + const Literal& literal, int device_ordinal); + + // Construct and return a literal containing the array represented by + // shaped_buffer. + std::unique_ptr ShapedBufferToLiteral( + const ShapedBuffer& shaped_buffer); + + // Helper for converting a ShapedBuffer into a literal. + void CopyShapedBufferToLiteral(const ShapedBuffer& shaped_buffer, + ShapeIndex* index, Literal* literal); + + // Execute the given computation on the local client. With and without + // options. 
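A hedged sketch of how a test built on this fixture typically strings these helpers together end to end (the fixture subclass, the computation, and the values are invented; the called methods are the ones declared in this class):

#include <memory>

#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/local_client_test_base.h"

namespace xla {
namespace {

class LocalClientSketchTest : public LocalClientTestBase {};

TEST_F(LocalClientSketchTest, AddTwoVectors) {
  ComputationBuilder builder(local_client_, TestName());
  auto x = builder.ConstantR1<float>({1.0f, 2.0f});
  auto y = builder.ConstantR1<float>({3.0f, 4.0f});
  builder.Add(x, y);
  Computation computation = builder.Build().ConsumeValueOrDie();

  // Runs on the default device with DefaultLocalExecuteOptions(), then pulls
  // the result back to the host as a Literal for comparison.
  std::unique_ptr<ScopedShapedBuffer> result =
      ExecuteLocally(computation, /*arguments=*/{});
  std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(*result);
  LiteralTestUtil::ExpectR1Equal<float>({4.0f, 6.0f}, *result_literal);
}

}  // namespace
}  // namespace xla

The overloads declared next cover both the returning form and the form that writes into a caller-provided ShapedBuffer.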
+ std::unique_ptr ExecuteLocally( + const Computation& computation, + tensorflow::gtl::ArraySlice arguments); + std::unique_ptr ExecuteLocally( + const Computation& computation, + tensorflow::gtl::ArraySlice arguments, + const LocalExecuteOptions& options); + + // Returns a default set of execute options, configured to use allocator_ + // as the allocator. + LocalExecuteOptions DefaultLocalExecuteOptions() const; + + // Overloads which write result into the given buffer. + void ExecuteLocally( + const Computation& computation, + tensorflow::gtl::ArraySlice arguments, + ShapedBuffer* result); + void ExecuteLocally( + const Computation& computation, + tensorflow::gtl::ArraySlice arguments, + const LocalExecuteOptions& options, ShapedBuffer* result); + + // Convert a ShapedBuffer into a ScopedShaped buffer so that all buffers are + // deallocated when the object is destructed. + std::unique_ptr ShapedBufferToScopedShapedBuffer( + std::unique_ptr shaped_buffer, + DeviceMemoryAllocator* allocator); + + string TestName() const { + return ::testing::UnitTest::GetInstance()->current_test_info()->name(); + } + + // The allocator must live as long as the service which lives until the end of + // the process, so make the allocator static. + static TestAllocator* allocator_; + + perftools::gputools::StreamExecutor* stream_executor_; + TransferManager* transfer_manager_; + + LocalClient* local_client_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_TESTS_LOCAL_CLIENT_TEST_BASE_H_ diff --git a/tensorflow/compiler/xla/tests/log_test.cc b/tensorflow/compiler/xla/tests/log_test.cc new file mode 100644 index 0000000000..b520d89de3 --- /dev/null +++ b/tensorflow/compiler/xla/tests/log_test.cc @@ -0,0 +1,75 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +class LogTest : public ClientLibraryTestBase {}; + +XLA_TEST_F(LogTest, LogZeroValues) { + ComputationBuilder builder(client_, TestName()); + auto x = builder.ConstantR3FromArray3D(Array3D(3, 0, 0)); + builder.Log(x); + + ComputeAndCompareR3(&builder, Array3D(3, 0, 0), {}, + ErrorSpec(0.0001)); +} + +TEST_F(LogTest, LogTenValues) { + std::vector input = {-0.0, 1.0, 2.0, -3.0, -4.0, + 5.0, 6.0, -7.0, -8.0, 9.0}; + + ComputationBuilder builder(client_, TestName()); + auto x = builder.ConstantR1(input); + builder.Log(x); + + std::vector expected; + for (float f : input) { + expected.push_back(std::log(f)); + } + + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/map_test.cc b/tensorflow/compiler/xla/tests/map_test.cc new file mode 100644 index 0000000000..014417a205 --- /dev/null +++ b/tensorflow/compiler/xla/tests/map_test.cc @@ -0,0 +1,589 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/compiler/xla/xla.pb.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class MapTest : public ClientLibraryTestBase { + public: + explicit MapTest(perftools::gputools::Platform* platform = nullptr) + : ClientLibraryTestBase(platform, + /*disabled_pass_names=*/{"algsimp", "inline"}) {} + + // Creates a function that adds its scalar argument with the constant 1.0. + // + // x {R0F32} ----> (add) + // / + // 1.0f ---------/ + Computation CreateAdderToOne() { + ComputationBuilder mapped_builder(client_, TestName()); + auto x = mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); + auto one = mapped_builder.ConstantR0(1.0); + auto adder_to_one = mapped_builder.Add(x, one); + auto computation_status = mapped_builder.Build(); + TF_CHECK_OK(computation_status.status()); + return computation_status.ConsumeValueOrDie(); + } + + Computation CreateMax() { + ComputationBuilder b(client_, TestName()); + auto lhs = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); + auto rhs = b.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); + b.Max(lhs, rhs); + auto computation_status = b.Build(); + TF_CHECK_OK(computation_status.status()); + return computation_status.ConsumeValueOrDie(); + } + + // Creates a computation that accepts an F32 and returns T(1) (ignoring the + // argument). + template + Computation CreateScalarOne() { + ComputationBuilder mapped_builder(client_, "scalar_one"); + (void)mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); + mapped_builder.ConstantR0(1); + auto computation_status = mapped_builder.Build(); + TF_CHECK_OK(computation_status.status()); + return computation_status.ConsumeValueOrDie(); + } + + // Creates a function that multiplies its scalar argument by the constant 2.0 + // + // x {R0F32} ----> (mul) + // / + // 2.0f ---------/ + Computation CreateMulByTwo() { + ComputationBuilder mapped_builder(client_, TestName()); + auto x = mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); + auto two = mapped_builder.ConstantR0(2.0); + auto mul_by_two = mapped_builder.Mul(x, two); + auto computation_status = mapped_builder.Build(); + TF_CHECK_OK(computation_status.status()); + return computation_status.ConsumeValueOrDie(); + } + + // Creates a function that adds its scalar argument with the constant 1.0 and + // then multiplies by the original element. 
+ // + // /---------------\ + // / \ + // x {R0F32} ----> (add) ----> (mul) + // / + // 1.0f ---------/ + Computation CreateAdderToOneTimesItself() { + ComputationBuilder mapped_builder(client_, TestName()); + auto x = mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); + auto one = mapped_builder.ConstantR0(1.0); + auto adder_to_one = mapped_builder.Add(x, one); + auto result = mapped_builder.Mul(x, adder_to_one); + auto computation_status = mapped_builder.Build(); + TF_CHECK_OK(computation_status.status()); + return computation_status.ConsumeValueOrDie(); + } + + // Creates a function that takes a single parameter and calls map with + // "embedded_computation" on it, and then adds "n" to the result. + // + // x {R0F32} -----------> (map) ----> (add) + // / / + // embedded_computation --/ n --/ + Computation CreateMapPlusN(const Computation& embedded_computation, float n) { + ComputationBuilder builder(client_, TestName()); + auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); + auto map = builder.Map({x}, embedded_computation); + auto constant_n = builder.ConstantR0(n); + auto add = builder.Add(map, constant_n); + auto computation_status = builder.Build(); + TF_CHECK_OK(computation_status.status()); + return computation_status.ConsumeValueOrDie(); + } + + // Creates a binary function with signature (F32, F32) -> Pred + // defined by (x, y) -> x > y. + Computation CreateGt() { + ComputationBuilder b(client_, "Gt"); + auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = b.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); + auto gt = b.Gt(x, y); + auto computation_status = b.Build(); + TF_CHECK_OK(computation_status.status()); + return computation_status.ConsumeValueOrDie(); + } + + // Creates a function that adds three scalar arguments + // + // x {R0F32} ----\ + // \ + // y {R0F32} ----> (add) ---> (add) + // / + // z {R0F32} ---------------/ + Computation CreateTernaryAdder() { + ComputationBuilder mapped_builder(client_, "TernaryAdder"); + auto x = mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = mapped_builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); + auto z = mapped_builder.Parameter(2, ShapeUtil::MakeShape(F32, {}), "z"); + auto xy = mapped_builder.Add(x, y); + auto xyz = mapped_builder.Add(xy, z); + auto computation_status = mapped_builder.Build(); + TF_CHECK_OK(computation_status.status()); + return computation_status.ConsumeValueOrDie(); + } +}; + +TEST_F(MapTest, MapEachElemPlusOneR0) { + // Applies lambda (x) (+ x 1)) to an input scalar. + ComputationBuilder builder(client_, TestName()); + std::unique_ptr param0_literal = LiteralUtil::CreateR0(42.0); + std::unique_ptr param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + + auto param = builder.Parameter(0, param0_literal->shape(), "param0"); + auto map = builder.Map({param}, CreateAdderToOne()); + + ComputeAndCompareR0(&builder, 43.0, {param0_data.get()}, + ErrorSpec(0.01f)); +} + +XLA_TEST_F(MapTest, MapEachElemPlusOneR1S0) { + // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0. 
+ ComputationBuilder builder(client_, TestName()); + std::unique_ptr param0_literal = LiteralUtil::CreateR1({}); + std::unique_ptr param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + + auto param = builder.Parameter(0, param0_literal->shape(), "param0"); + auto map = builder.Map({param}, CreateAdderToOne()); + + ComputeAndCompareR1(&builder, {}, {param0_data.get()}, + ErrorSpec(0.01f)); +} + +TEST_F(MapTest, MapEachElemPlusOneR1S4) { + // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 4. + ComputationBuilder builder(client_, TestName()); + std::unique_ptr param0_literal = + LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); + std::unique_ptr param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + + auto param = builder.Parameter(0, param0_literal->shape(), "param0"); + auto map = builder.Map({param}, CreateAdderToOne()); + + ComputeAndCompareR1(&builder, {3.2f, 4.3f, 5.4f, 6.5f}, + {param0_data.get()}, ErrorSpec(0.01f)); +} + +TEST_F(MapTest, MapEachF32ElementToS32Constant) { + ComputationBuilder builder(client_, TestName()); + std::unique_ptr param0_literal = + LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); + std::unique_ptr param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + + auto param = builder.Parameter(0, param0_literal->shape(), "param0"); + auto map = builder.Map({param}, CreateScalarOne()); + + ComputeAndCompareR1(&builder, {1, 1, 1, 1}, {param0_data.get()}); +} + +TEST_F(MapTest, MapEachF32ElementToU32Constant) { + ComputationBuilder builder(client_, TestName()); + std::unique_ptr param0_literal = + LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); + std::unique_ptr param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + + auto param = builder.Parameter(0, param0_literal->shape(), "param0"); + auto map = builder.Map({param}, CreateScalarOne()); + + ComputeAndCompareR1(&builder, {1, 1, 1, 1}, {param0_data.get()}); +} + +TEST_F(MapTest, MapEachElemLongerChainR1) { + // Maps (lambda (x) (* (+ x 1) x)) onto an input R1F32 vector. + ComputationBuilder builder(client_, TestName()); + std::unique_ptr param0_literal = + LiteralUtil::CreateR1({2.6f, -5.1f, 0.1f, 0.2f, 999.0f, 255.5f}); + std::unique_ptr param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + + auto param = builder.Parameter(0, param0_literal->shape(), "param0"); + auto map = builder.Map({param}, CreateAdderToOneTimesItself()); + + ComputeAndCompareR1( + &builder, {9.36f, 20.91f, 0.11f, 0.24f, 999000.0f, 65535.75f}, + {param0_data.get()}, ErrorSpec(0.01f)); +} + +XLA_TEST_F(MapTest, MapMultipleMapsR1S0) { + // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0, and then + // maps (lambda (x) (* x 2)) on the result. + ComputationBuilder builder(client_, TestName()); + std::unique_ptr param0_literal = LiteralUtil::CreateR1({}); + std::unique_ptr param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + + auto param = builder.Parameter(0, param0_literal->shape(), "param0"); + auto map1 = builder.Map({param}, CreateAdderToOne()); + auto map2 = builder.Map({map1}, CreateMulByTwo()); + + ComputeAndCompareR1(&builder, {}, {param0_data.get()}, + ErrorSpec(0.01f)); +} + +TEST_F(MapTest, MapMultipleMapsR1S4) { + // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 4, and then + // maps (lambda (x) (* x 2)) on the result. 
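+  // For the first input element this composes to (2.2 + 1.0) * 2.0 = 6.4;
+  // the remaining elements follow the same pattern, giving the expected
+  // values {6.4, 8.6, 10.8, 13.0} checked below.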
+ ComputationBuilder builder(client_, TestName()); + std::unique_ptr param0_literal = + LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); + std::unique_ptr param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + + auto param = builder.Parameter(0, param0_literal->shape(), "param0"); + auto map1 = builder.Map({param}, CreateAdderToOne()); + auto map2 = builder.Map({map1}, CreateMulByTwo()); + + ComputeAndCompareR1(&builder, {6.4f, 8.6f, 10.8f, 13.0f}, + {param0_data.get()}, ErrorSpec(0.01f)); +} + +TEST_F(MapTest, MapEachElemPlusOneR2) { + // Maps (lambda (x) (+ x 1)) onto an input R2F32 vector. + ComputationBuilder builder(client_, TestName()); + std::unique_ptr param0_literal = LiteralUtil::CreateR2( + {{13.25f, 14.0f}, {-7.1f, -7.2f}, {-8.8f, 8.8f}}); + std::unique_ptr param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + + auto param = builder.Parameter(0, param0_literal->shape(), "param0"); + auto map = builder.Map({param}, CreateAdderToOne()); + + Array2D expected_array( + {{14.25f, 15.0f}, {-6.1f, -6.2f}, {-7.8f, 9.8f}}); + ComputeAndCompareR2(&builder, expected_array, {param0_data.get()}, + ErrorSpec(0.01f)); +} + +XLA_TEST_F(MapTest, ComplexNestedMaps) { + // Constructs a complex graph of embedded computations to test the computation + // lowering order. Python equivalent: + // + // embed1 = lambda x: x + 1 # x + 1 + // embed2 = lambda x: embed1(x) + 2 # x + 3 + // embed3 = lambda x: embed1(x) + 4 # x + 5 + // embed4 = lambda x: embed2(x) + embed3(x) # 2x + 8 + // embed5 = lambda x: embed2(x) + 6 # x + 9 + // result = embed5(42) + embed4(7) # (42 + 9) + (2 * 7 + 8) = 73 + + Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + + auto embed1 = CreateAdderToOne(); + auto embed2 = CreateMapPlusN(embed1, 2.0); + auto embed3 = CreateMapPlusN(embed1, 4.0); + + ComputationBuilder embed4_builder(client_, "embed4"); + auto embed4_param = embed4_builder.Parameter(0, scalar_shape, "x"); + auto embed4_map_lhs = embed4_builder.Map({embed4_param}, embed2); + auto embed4_map_rhs = embed4_builder.Map({embed4_param}, embed3); + auto embed4_add = embed4_builder.Add(embed4_map_lhs, embed4_map_rhs); + auto embed4_status = embed4_builder.Build(); + ASSERT_IS_OK(embed4_status.status()); + auto embed4 = embed4_status.ConsumeValueOrDie(); + + auto embed5 = CreateMapPlusN(embed2, 6.0); + + ComputationBuilder builder(client_, TestName()); + auto constant_42 = builder.ConstantR0(42.0); + auto constant_7 = builder.ConstantR0(7.0); + auto map_42 = builder.Map({constant_42}, embed5); + auto map_7 = builder.Map({constant_7}, embed4); + builder.Add(map_42, map_7); + + ComputeAndCompareR0(&builder, 73.0, {}, ErrorSpec(0.01f)); +} + +TEST_F(MapTest, VersionedEmbeddedComputation) { + // Build a computation X, use it in a map, then add an additional operation to + // computation X and use it again in a different map. Verify that the proper + // versions of computation X are used in each of the maps. + + // Create a (embedded) computation which adds one to its parameter argument. 
+ ComputationBuilder embedded_builder(client_, "EmbeddedComputation"); + auto param_0 = + embedded_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param0"); + auto constant_one = embedded_builder.ConstantR0(1.0); + auto adder_to_one = embedded_builder.Add(param_0, constant_one); + auto computation_status = embedded_builder.Build(); + ASSERT_IS_OK(computation_status.status()); + auto embedded_computation = computation_status.ConsumeValueOrDie(); + + ComputationBuilder builder(client_, TestName()); + auto constant_vector = builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); + auto map_plus_1 = builder.Map({constant_vector}, embedded_computation); + + // Add another Add(1) operation to the existing embedded computation. This + // requires using the stub interface because the ComputationBuilder does not + // allow modification to the Computation objects after they have been built. + BinaryOpRequest request; + request.set_binop(BINOP_ADD); + *request.mutable_lhs() = adder_to_one; + *request.mutable_rhs() = constant_one; + OpRequest op_request; + *op_request.mutable_computation() = embedded_computation.handle(); + *op_request.mutable_binary_op_request() = request; + OpResponse response; + tensorflow::Status s = client_->stub()->Op(&op_request, &response); + ASSERT_TRUE(s.ok()); + + auto map_plus_2 = builder.Map({map_plus_1}, embedded_computation); + + // The original vector has Add(1) applied to it with a map, followed by + // Add(1+1) resulting in a net Add(3). + ComputeAndCompareR1(&builder, {4.0, 5.0, 6.0, 7.0}, {}, + ErrorSpec(0.01f)); +} + +TEST_F(MapTest, MapBinaryAdder) { + // Maps (lambda (x y) (+ x y)) onto two R1F32 vectors. + ComputationBuilder builder(client_, TestName()); + std::unique_ptr param0_literal = + LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); + std::unique_ptr param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + std::unique_ptr param1_literal = + LiteralUtil::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); + std::unique_ptr param1_data = + client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + + auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); + auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); + auto map = + builder.Map({param0, param1}, CreateScalarAddComputation(F32, &builder)); + + ComputeAndCompareR1(&builder, {7.3f, 7.7, 4.3f, 0}, + {param0_data.get(), param1_data.get()}, + ErrorSpec(0.01f)); +} + +// Adds two rank-2 arrays with different layouts. This test exercises a path +// for Map that used to fail in shape inference (b/28989438). 
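+// A layout is given as a minor-to-major dimension order: {1, 0} stores
+// dimension 1 fastest (row-major), while {0, 1} stores dimension 0 fastest
+// (column-major). The Map below must operate on the logical shapes and ignore
+// the differing physical layouts of the two operands.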
+XLA_TEST_F(MapTest, AddWithMixedLayouts) { + ComputationBuilder builder(client_, TestName()); + std::unique_ptr param0_literal = + test_utils::CreateR2LiteralWithLayout({{1, 2}, {3, 4}}, {1, 0}); + std::unique_ptr param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + + std::unique_ptr param1_literal = + test_utils::CreateR2LiteralWithLayout({{10, 20}, {30, 40}}, {0, 1}); + std::unique_ptr param1_data = + client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + + auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); + auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); + auto map = + builder.Map({param0, param1}, CreateScalarAddComputation(S32, &builder)); + + Array2D expected(2, 2); + expected(0, 0) = 11; + expected(0, 1) = 22; + expected(1, 0) = 33; + expected(1, 1) = 44; + ComputeAndCompareR2(&builder, expected, + {param0_data.get(), param1_data.get()}); +} + +XLA_TEST_F(MapTest, AddR3_3x0x2) { + ComputationBuilder builder(client_, TestName()); + std::unique_ptr param0_literal = + LiteralUtil::CreateR3FromArray3D(Array3D(3, 0, 2)); + std::unique_ptr param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + + std::unique_ptr param1_literal = + LiteralUtil::CreateR3FromArray3D(Array3D(3, 0, 2)); + std::unique_ptr param1_data = + client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + + auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); + auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); + auto map = + builder.Map({param0, param1}, CreateScalarAddComputation(S32, &builder)); + + ComputeAndCompareR3(&builder, Array3D(3, 0, 2), + {param0_data.get(), param1_data.get()}); +} + +TEST_F(MapTest, MapTernaryAdder) { + // Maps (lambda (x y z) (+ x y z)) onto three R1F32 vectors. + ComputationBuilder builder(client_, TestName()); + std::unique_ptr param0_literal = + LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); + std::unique_ptr param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + std::unique_ptr param1_literal = + LiteralUtil::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); + std::unique_ptr param1_data = + client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + std::unique_ptr param2_literal = + LiteralUtil::CreateR1({-10.0f, -100.0f, -900.0f, -400.0f}); + std::unique_ptr param2_data = + client_->TransferToServer(*param2_literal).ConsumeValueOrDie(); + + auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); + auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); + auto param2 = builder.Parameter(2, param2_literal->shape(), "param2"); + auto map = builder.Map({param0, param1, param2}, CreateTernaryAdder()); + + ComputeAndCompareR1( + &builder, {-2.7f, -92.3f, -895.7f, -400.0f}, + {param0_data.get(), param1_data.get(), param2_data.get()}, + ErrorSpec(0.01f)); +} + +TEST_F(MapTest, MapGt) { + // Maps (x,y) -> x > y onto two R1F32 vectors. + ComputationBuilder b(client_, TestName()); + auto gt = CreateGt(); + b.Map({b.ConstantR1({1, 20}), b.ConstantR1({10, 2})}, gt); + ComputeAndCompareR1(&b, {false, true}, {}); +} + +TEST_F(MapTest, NestedBinaryMap) { + Computation max_with_square; + { + // max_with_square(x) = do max(x, x^2) via a map. 
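+    // Elementwise this computes max(x, x * x): max(0.5, 0.25) = 0.5,
+    // max(-0.5, 0.25) = 0.25, and max(2.0, 4.0) = 4.0, matching the expected
+    // vector at the end of the test.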
+ ComputationBuilder b(client_, "max_with_square"); + auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); + b.Map({x, b.Mul(x, x)}, CreateMax()); + auto computation_status = b.Build(); + ASSERT_IS_OK(computation_status.status()); + max_with_square = computation_status.ConsumeValueOrDie(); + } + ComputationBuilder b(client_, TestName()); + auto input = b.ConstantR1({0.1f, 0.5f, -0.5f, 1.0f, 2.0f}); + b.Map({input}, max_with_square); + ComputeAndCompareR1(&b, {0.1f, 0.5f, 0.25f, 1.0f, 4.0f}, {}); +} + +TEST_F(MapTest, MapOperantionWithBuildError) { + // Maps (lambda (x y) (+ x y)) onto two R1F32 vectors but uses an unsupported + // type combination (F32 + U16) to test that the error is reported to the + // outermost ComputationBuilder. + ComputationBuilder builder(client_, TestName()); + + auto sub_builder = builder.CreateSubBuilder("ErrorAdd"); + auto x = sub_builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = sub_builder->Parameter(1, ShapeUtil::MakeShape(U16, {}), "y"); + auto adder = sub_builder->Add(x, y); + auto error_add = sub_builder->BuildAndNoteError(); + + std::unique_ptr param0_literal = + LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); + std::unique_ptr param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + std::unique_ptr param1_literal = + LiteralUtil::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); + std::unique_ptr param1_data = + client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + + auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); + auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); + auto map = builder.Map({param0, param1}, error_add); + + StatusOr computation_status = builder.Build(); + ASSERT_TRUE(!computation_status.ok()); + EXPECT_MATCH(computation_status.status().ToString(), + testing::HasSubstr("error from: ErrorAdd: binary op with " + "different element types: f32[] and u16[]")); +} + +// MapTest disables inline and algsimp. MapTestWithFullOpt runs all +// optimizations. +using MapTestWithFullOpt = ClientLibraryTestBase; + +// Regression test for b/31466798. The inliner simplifies map(param0, param1, +// power) to power(param0, param1) without deleting the old subcomputation which +// is the same as the new entry computation. HloSubcomputationUnification used +// to have issues with such patterns and maybe invalidate the pointer to entry +// computation. 
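+// The test below maps a scalar power computation over two R0 parameters, so
+// after full optimization it should still compute 2.0^5.0 = 32.0.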
+TEST_F(MapTestWithFullOpt, MapScalarPower) { + ComputationBuilder builder(client_, TestName()); + + auto sub_builder = builder.CreateSubBuilder("power"); + auto x = sub_builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = sub_builder->Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); + sub_builder->Pow(x, y); + auto power = sub_builder->BuildAndNoteError(); + + std::unique_ptr param0_literal = LiteralUtil::CreateR0(2.0f); + std::unique_ptr param1_literal = LiteralUtil::CreateR0(5.0f); + std::unique_ptr param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + std::unique_ptr param1_data = + client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + + auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); + auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); + builder.Map({param0, param1}, power); + + ComputeAndCompareR0(&builder, 32.0f, + {param0_data.get(), param1_data.get()}, + ErrorSpec(0.01f)); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc new file mode 100644 index 0000000000..8aa4029440 --- /dev/null +++ b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc @@ -0,0 +1,179 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/reference_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class MatOpsSimpleTest : public ClientLibraryTestBase { + protected: + Computation BuildSum() { + // sum(x, y) = x + y + ComputationBuilder builder(client_, "sum"); + auto x_value = + builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x_value"); + auto y_value = + builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y_value"); + builder.Add(x_value, y_value); + auto computation_status = builder.Build(); + TF_CHECK_OK(computation_status.status()); + return computation_status.ConsumeValueOrDie(); + } + + void TestLinspaceMax(int64 rows, int64 cols) { + float from = -128.0, to = 256.0; + std::unique_ptr> alhs = + MakeLinspaceArray2D(from, to, rows, cols); + auto arhs = MakeUnique>(rows, cols, 1.0); + + ComputationBuilder builder( + client_, + tensorflow::strings::Printf("max_%lldx%lld_linspace", rows, cols)); + auto lhs = builder.ConstantR2FromArray2D(*alhs); + auto rhs = builder.ConstantR2FromArray2D(*arhs); + auto max = builder.Max(lhs, rhs); + + Array2D aexpected(rows, cols); + for (int row = 0; row < rows; ++row) { + for (int col = 0; col < cols; ++col) { + aexpected(row, col) = std::max((*alhs)(row, col), (*arhs)(row, col)); + } + } + + ComputeAndCompareR2(&builder, aexpected, {}, ErrorSpec(1e-6)); + } +}; + +TEST_F(MatOpsSimpleTest, ExpTwoByTwoValues) { + ComputationBuilder builder(client_, "exp_2x2"); + auto data = builder.ConstantR2({ + {1.0, 0.0}, // row 0 + {-1.0, 0.5}, // row 1 + }); + builder.Exp(data); + + std::unique_ptr expected = + LiteralUtil::CreateR2({{2.71828, 1.00000}, // row 0 + {0.36788, 1.64872}}); // row 1 + + ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5)); +} + +TEST_F(MatOpsSimpleTest, MapTwoByTwo) { + Computation add_half; + { + // add_half(x) = x + 0.5 + ComputationBuilder builder(client_, "add_half"); + auto x_value = + builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x_value"); + auto half = builder.ConstantR0(0.5); + builder.Add(x_value, half); + auto computation_status = builder.Build(); + ASSERT_IS_OK(computation_status.status()); + add_half = computation_status.ConsumeValueOrDie(); + } + + ComputationBuilder builder(client_, "map_2x2"); + auto data = builder.ConstantR2({ + {1.0, 0.0}, // row 0 + {-1.0, 0.5}, // row 1 + }); + auto map = builder.Map({data}, add_half); + + std::unique_ptr expected = + LiteralUtil::CreateR2({{1.5, 0.5}, // row 0 + {-0.5, 1.0}}); // row 1 + ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5)); 
+} + +TEST_F(MatOpsSimpleTest, MaxTwoByTwoValues) { + ComputationBuilder builder(client_, "max_2x2"); + auto lhs = builder.ConstantR2({ + {7.0, 2.0}, // row 0 + {3.0, -4.0}, // row 1 + }); + auto rhs = builder.ConstantR2({ + {5.0, 6.0}, // row 0 + {1.0, -8.0}, // row 1 + }); + auto max = builder.Max(lhs, rhs); + + std::unique_ptr expected = + LiteralUtil::CreateR2({{7.0, 6.0}, // row 0 + {3.0, -4.0}}); // row 1 + ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6)); +} + +TEST_F(MatOpsSimpleTest, Max1x1Linspace) { TestLinspaceMax(1, 1); } + +TEST_F(MatOpsSimpleTest, Max2x2Linspace) { TestLinspaceMax(2, 2); } + +TEST_F(MatOpsSimpleTest, Max3x3Linspace) { TestLinspaceMax(3, 3); } + +TEST_F(MatOpsSimpleTest, Max4x4Linspace) { TestLinspaceMax(4, 4); } + +TEST_F(MatOpsSimpleTest, Max6x6Linspace) { TestLinspaceMax(6, 6); } + +TEST_F(MatOpsSimpleTest, Max8x8Linspace) { TestLinspaceMax(8, 8); } + +TEST_F(MatOpsSimpleTest, Max12x12Linspace) { TestLinspaceMax(12, 12); } + +TEST_F(MatOpsSimpleTest, Max16x16Linspace) { TestLinspaceMax(16, 16); } + +TEST_F(MatOpsSimpleTest, Max32x8Linspace) { TestLinspaceMax(32, 8); } + +TEST_F(MatOpsSimpleTest, Max64x8Linspace) { TestLinspaceMax(64, 8); } + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc b/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc new file mode 100644 index 0000000000..2cd680399b --- /dev/null +++ b/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc @@ -0,0 +1,74 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Tests that slice operations can be performed. 
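+//
+// Slice(operand, start_indices, limit_indices) extracts the half-open
+// hyper-rectangle [start, limit) in every dimension, so each output dimension
+// has size limit_indices[i] - start_indices[i]. For example, slicing the 4x3
+// constant below with starts {2, 1} and limits {4, 3} yields a 2x2 result.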
+ +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/array3d.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +class SliceTest : public ClientLibraryTestBase {}; + +XLA_TEST_F(SliceTest, Slice2D) { + ComputationBuilder builder(client_, "slice_2d"); + auto original = builder.ConstantR2( + {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, {7.0, 8.0, 9.0}, {10.0, 11.0, 12.0}}); + builder.Slice(original, {2, 1}, {4, 3}); + + Array2D expected({{8.0f, 9.0f}, {11.0f, 12.0f}}); + ComputeAndCompareR2(&builder, expected, {}, ErrorSpec(0.000001)); +} + +XLA_TEST_F(SliceTest, Slice3D) { + ComputationBuilder builder(client_, "slice_3d"); + Array3D array_3d( + {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}); + auto original = builder.ConstantR3FromArray3D(array_3d); + builder.Slice(original, {0, 0, 1}, {2, 1, 2}); + + Array3D expected_3d({{{2.0f}}, {{6.0f}}}); + ComputeAndCompareR3(&builder, expected_3d, {}, ErrorSpec(0.000001)); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/pad_test.cc b/tensorflow/compiler/xla/tests/pad_test.cc new file mode 100644 index 0000000000..d3400b432f --- /dev/null +++ b/tensorflow/compiler/xla/tests/pad_test.cc @@ -0,0 +1,420 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/reference_util.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class PadTest : public ClientLibraryTestBase { + protected: + PadTest() { + // Initializes the padding configuration used for R4 tests. + // Pad only on the dimension 0 {low: 1, high: 0, interior: 2} and + // dimension 1 {low: 0, high: 2, interior: 1}. + auto dimension0 = r4_padding_on_dim0_dim1_.add_dimensions(); + dimension0->set_edge_padding_low(1); + dimension0->set_edge_padding_high(0); + dimension0->set_interior_padding(2); + auto dimension1 = r4_padding_on_dim0_dim1_.add_dimensions(); + dimension1->set_edge_padding_low(0); + dimension1->set_edge_padding_high(2); + dimension1->set_interior_padding(1); + auto dimension2 = r4_padding_on_dim0_dim1_.add_dimensions(); + dimension2->set_edge_padding_low(0); + dimension2->set_edge_padding_high(0); + dimension2->set_interior_padding(0); + auto dimension3 = r4_padding_on_dim0_dim1_.add_dimensions(); + dimension3->set_edge_padding_low(0); + dimension3->set_edge_padding_high(0); + dimension3->set_interior_padding(0); + } + + // Padding configuration for R4 that only pads dimension 0 and 1. + PaddingConfig r4_padding_on_dim0_dim1_; +}; + +// Tests a Pad() with a zero-element input and output. +XLA_TEST_F(PadTest, Pad1DS0ToS0Array) { + ComputationBuilder b(client_, TestName()); + // Set up the padding configuration {low: 0, high: 0, interior: 0}. + PaddingConfig padding_config; + auto dimension = padding_config.add_dimensions(); + dimension->set_edge_padding_low(0); + dimension->set_edge_padding_high(0); + dimension->set_interior_padding(0); + + b.Pad(b.ConstantR1({}), b.ConstantR0(0.1), padding_config); + ComputeAndCompareR1(&b, {}, {}, ErrorSpec(0.0001)); +} + +// Tests a Pad() with a zero-element input but a non-zero-element output. +XLA_TEST_F(PadTest, Pad1DS0ToS5Array) { + ComputationBuilder b(client_, TestName()); + // Set up the padding configuration {low: 3, high: 0, interior: 1}. + PaddingConfig padding_config; + auto dimension = padding_config.add_dimensions(); + dimension->set_edge_padding_low(1); + dimension->set_edge_padding_high(4); + dimension->set_interior_padding(7); + + b.Pad(b.ConstantR1({}), b.ConstantR0(0.1), padding_config); + ComputeAndCompareR1(&b, std::vector(5, 0.1), {}, + ErrorSpec(0.0001)); +} + +XLA_TEST_F(PadTest, Pad1DS3Array) { + ComputationBuilder b(client_, TestName()); + // Set up the padding configuration {low: 3, high: 0, interior: 1}. 
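+  // With this configuration the padded length is
+  //   low + size + (size - 1) * interior + high = 3 + 3 + 2 * 1 + 0 = 8,
+  // matching the eight expected values constructed below.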
+ PaddingConfig padding_config; + auto dimension = padding_config.add_dimensions(); + dimension->set_edge_padding_low(3); + dimension->set_edge_padding_high(0); + dimension->set_interior_padding(1); + + b.Pad(b.ConstantR1({1, 2, 3}), b.ConstantR0(0.1), + padding_config); + std::vector expected({0.1, 0.1, 0.1, 1, 0.1, 2, 0.1, 3}); + ComputeAndCompareR1(&b, expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(PadTest, Pad4D_2x0x3x2_FloatArray) { + ComputationBuilder b(client_, TestName()); + b.Pad(b.ConstantR4FromArray4D(Array4D(2, 0, 3, 2)), + b.ConstantR0(1.5), r4_padding_on_dim0_dim1_); + ComputeAndCompareR4(&b, Array4D(5, 2, 3, 2, 1.5f), {}, + ErrorSpec(0.0001)); +} + +TEST_F(PadTest, Pad4DFloat_1x1x3x2_Array) { + ComputationBuilder b(client_, TestName()); + auto input = MakeUnique>(1, 1, 3, 2); + Array2D input_xy({ + {1.0f, 2.0f}, // row 0 + {3.0f, 4.0f}, // row 1 + {5.0f, 6.0f}, // row 2 + }); + input->FillWithYX(input_xy); + + b.Pad(b.ConstantR4FromArray4D(*input), b.ConstantR0(1.5), + r4_padding_on_dim0_dim1_); + + auto expected = MakeUnique>(2, 3, 3, 2); + expected->Fill(1.5); + (*expected)(1, 0, 0, 0) = 1.0f; + (*expected)(1, 0, 0, 1) = 2.0f; + (*expected)(1, 0, 1, 0) = 3.0f; + (*expected)(1, 0, 1, 1) = 4.0f; + (*expected)(1, 0, 2, 0) = 5.0f; + (*expected)(1, 0, 2, 1) = 6.0f; + ComputeAndCompareR4(&b, *expected, {}, ErrorSpec(0.0001)); +} + +TEST_F(PadTest, Pad4DFloatArrayWithInteriorPadding) { + ComputationBuilder b(client_, TestName()); + + const float pad_value = 1.5f; + Array4D input(3, 2, 1, 1, {1, 2, 3, 4, 5, 6}); + b.Pad(b.ConstantR4FromArray4D(input), b.ConstantR0(pad_value), + r4_padding_on_dim0_dim1_); + + auto expected = MakeUnique>(8, 5, 1, 1); + expected->Fill(pad_value); + (*expected)(1, 0, 0, 0) = 1.0f; + (*expected)(1, 2, 0, 0) = 2.0f; + (*expected)(4, 0, 0, 0) = 3.0f; + (*expected)(4, 2, 0, 0) = 4.0f; + (*expected)(7, 0, 0, 0) = 5.0f; + (*expected)(7, 2, 0, 0) = 6.0f; + ComputeAndCompareR4(&b, *expected, {}, ErrorSpec(0.0001)); +} + +TEST_F(PadTest, Pad4DFloatArrayMinorFirstSmall) { + ComputationBuilder b(client_, TestName()); + + PaddingConfig padding_config; + auto dimension0 = padding_config.add_dimensions(); + dimension0->set_edge_padding_low(0); + dimension0->set_edge_padding_high(0); + dimension0->set_interior_padding(0); + auto dimension1 = padding_config.add_dimensions(); + dimension1->set_edge_padding_low(0); + dimension1->set_edge_padding_high(0); + dimension1->set_interior_padding(0); + auto dimension2 = padding_config.add_dimensions(); + dimension2->set_edge_padding_low(2); + dimension2->set_edge_padding_high(1); + dimension2->set_interior_padding(0); + auto dimension3 = padding_config.add_dimensions(); + dimension3->set_edge_padding_low(2); + dimension3->set_edge_padding_high(3); + dimension3->set_interior_padding(0); + + const Layout layout = LayoutUtil::MakeLayout({0, 1, 2, 3}); + + const float pad_value = -5.123f; + Array4D input_array(1, 1, 2, 3, {1, 2, 3, 4, 5, 6}); + auto input = LiteralUtil::CreateR4FromArray4D(input_array); + input = LiteralUtil::Relayout(*input, layout); + + b.Pad(b.ConstantLiteral(*input), b.ConstantR0(pad_value), padding_config); + + Array4D expected_array(1, 1, 5, 8); + expected_array.Fill(pad_value); + expected_array(0, 0, 2, 2) = 1.0f; + expected_array(0, 0, 2, 3) = 2.0f; + expected_array(0, 0, 2, 4) = 3.0f; + expected_array(0, 0, 3, 2) = 4.0f; + expected_array(0, 0, 3, 3) = 5.0f; + expected_array(0, 0, 3, 4) = 6.0f; + ComputeAndCompareR4(&b, expected_array, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(PadTest, 
Pad4DFloatArrayMinorFirstNonTrivialMinorDimensions) { + ComputationBuilder b(client_, TestName()); + + PaddingConfig padding_config; + auto dimension0 = padding_config.add_dimensions(); + dimension0->set_edge_padding_low(0); + dimension0->set_edge_padding_high(0); + dimension0->set_interior_padding(0); + auto dimension1 = padding_config.add_dimensions(); + dimension1->set_edge_padding_low(0); + dimension1->set_edge_padding_high(0); + dimension1->set_interior_padding(0); + auto dimension2 = padding_config.add_dimensions(); + dimension2->set_edge_padding_low(2); + dimension2->set_edge_padding_high(2); + dimension2->set_interior_padding(1); + auto dimension3 = padding_config.add_dimensions(); + dimension3->set_edge_padding_low(2); + dimension3->set_edge_padding_high(2); + dimension3->set_interior_padding(0); + + const Layout layout = LayoutUtil::MakeLayout({0, 1, 2, 3}); + + const float pad_value = -5.123f; + Array4D input_array(1, 25, 7, 7); + input_array.Fill(pad_value); + input_array(0, 0, 0, 0) = 1.0f; + input_array(0, 24, 6, 6) = 2.0f; + input_array(0, 17, 2, 5) = 3.0f; + auto input = LiteralUtil::CreateR4FromArray4D(input_array); + input = LiteralUtil::Relayout(*input, layout); + + b.Pad(b.ConstantLiteral(*input), b.ConstantR0(pad_value), padding_config); + + Array4D expected_array(1, 25, 17, 11); + expected_array.Fill(pad_value); + expected_array(0, 0, 2, 2) = 1.0f; + expected_array(0, 24, 14, 8) = 2.0f; + expected_array(0, 17, 6, 7) = 3.0f; + ComputeAndCompareR4(&b, expected_array, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(PadTest, Pad4DU8Array) { + ComputationBuilder b(client_, TestName()); + auto input = MakeUnique>(1, 1, 3, 2); + Array2D input_xy({ + {1, 2}, // row 0 + {3, 4}, // row 1 + {5, 6}, // row 2 + }); + input->FillWithYX(input_xy); + + b.Pad(b.ConstantR4FromArray4D(*input), b.ConstantR0(35), + r4_padding_on_dim0_dim1_); + + auto expected = MakeUnique>(2, 3, 3, 2); + expected->Fill(35); + (*expected)(1, 0, 0, 0) = 1; + (*expected)(1, 0, 0, 1) = 2; + (*expected)(1, 0, 1, 0) = 3; + (*expected)(1, 0, 1, 1) = 4; + (*expected)(1, 0, 2, 0) = 5; + (*expected)(1, 0, 2, 1) = 6; + ComputeAndCompareR4(&b, *expected, {}); +} + +XLA_TEST_F(PadTest, Pad4DPredArray) { + ComputationBuilder b(client_, TestName()); + + // Since bool is currently not well supported, use Broadcast operation to + // create the operand for Pad. + auto input = b.Broadcast(b.ConstantR0(true), {1, 1, 3, 2}); + auto padded = + b.Pad(input, b.ConstantR0(false), r4_padding_on_dim0_dim1_); + + // For the same reason, use Select to convert boolean values to int32. 
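+  // Select(pred, on_true, on_false) chooses elementwise, so positions that
+  // were padded with false become 0 and the broadcast true values become 1 in
+  // the resulting S32 array.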
+ auto zeros = MakeUnique>(2, 3, 3, 2); + auto ones = MakeUnique>(2, 3, 3, 2); + zeros->Fill(0); + ones->Fill(1); + b.Select(padded, b.ConstantR4FromArray4D(*ones), + b.ConstantR4FromArray4D(*zeros)); + + auto expected = MakeUnique>(2, 3, 3, 2); + expected->Fill(0); + (*expected)(1, 0, 0, 0) = 1; + (*expected)(1, 0, 0, 1) = 1; + (*expected)(1, 0, 1, 0) = 1; + (*expected)(1, 0, 1, 1) = 1; + (*expected)(1, 0, 2, 0) = 1; + (*expected)(1, 0, 2, 1) = 1; + ComputeAndCompareR4(&b, *expected, {}); +} + +XLA_TEST_F(PadTest, Large2DPad) { + ComputationBuilder b(client_, TestName()); + + auto input = b.Parameter(0, ShapeUtil::MakeShape(F32, {4, 4}), "input"); + PaddingConfig padding_config = MakeNoPaddingConfig(2); + for (int dim : {0, 1}) { + padding_config.mutable_dimensions(dim)->set_edge_padding_low( + 98 + 100 * (1 - dim)); + padding_config.mutable_dimensions(dim)->set_edge_padding_high(58 + + 100 * dim); + } + auto padded = b.Pad(input, b.ConstantR0(0.0f), padding_config); + + auto ones = MakeUnique>(4, 4); + ones->Fill(1.0f); + auto input_literal = LiteralUtil::CreateR2FromArray2D(*ones); + std::unique_ptr input_data = + client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + + auto expected = ReferenceUtil::PadArray2D(*ones, padding_config, 0.0f); + ComputeAndCompareR2(&b, *expected, {input_data.get()}); +} + +XLA_TEST_F(PadTest, AllTypes2DPad) { + ComputationBuilder b(client_, TestName()); + + constexpr int64 in_rows = 35; + constexpr int64 in_cols = 35; + auto input = + b.Parameter(0, ShapeUtil::MakeShape(F32, {in_rows, in_cols}), "input"); + PaddingConfig padding_config = MakeNoPaddingConfig(2); + padding_config.mutable_dimensions(0)->set_edge_padding_low(7); + padding_config.mutable_dimensions(0)->set_edge_padding_high(5); + padding_config.mutable_dimensions(0)->set_interior_padding(3); + padding_config.mutable_dimensions(1)->set_edge_padding_low(6); + padding_config.mutable_dimensions(1)->set_edge_padding_high(4); + padding_config.mutable_dimensions(1)->set_interior_padding(2); + auto padded = b.Pad(input, b.ConstantR0(3.14f), padding_config); + + auto operand = MakeUnique>(in_rows, in_cols); + operand->FillUnique(0.0f); + auto input_literal = LiteralUtil::CreateR2FromArray2D(*operand); + std::unique_ptr input_data = + client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + + auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 3.14f); + ComputeAndCompareR2(&b, *expected, {input_data.get()}, + ErrorSpec{0.0001}); +} + +XLA_TEST_F(PadTest, High2DPad) { + ComputationBuilder b(client_, TestName()); + + constexpr int64 in_rows = 129; + constexpr int64 in_cols = 129; + constexpr int64 low_padding = 0; + int64 high_padding[2] = {5, 7}; + constexpr int64 interior_padding = 0; + auto input = + b.Parameter(0, ShapeUtil::MakeShape(F32, {in_rows, in_cols}), "input"); + PaddingConfig padding_config = MakeNoPaddingConfig(2); + for (int dim : {0, 1}) { + padding_config.mutable_dimensions(dim)->set_edge_padding_low(low_padding); + padding_config.mutable_dimensions(dim)->set_edge_padding_high( + high_padding[dim]); + padding_config.mutable_dimensions(dim)->set_interior_padding( + interior_padding); + } + auto padded = b.Pad(input, b.ConstantR0(2.718f), padding_config); + + auto operand = MakeUnique>(in_rows, in_cols); + operand->FillUnique(1.0f); + auto input_literal = LiteralUtil::CreateR2FromArray2D(*operand); + auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); + std::unique_ptr input_data = + 
client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + + ComputeAndCompareR2(&b, *expected, {input_data.get()}, + ErrorSpec(0.0001)); +} + +// Regression test for b/31827337. +XLA_TEST_F(PadTest, ReducePad) { + ComputationBuilder b(client_, TestName()); + auto input = b.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2, 2, 2}), "input"); + + Computation add_f32 = CreateScalarAddComputation(F32, &b); + auto reduce = b.Reduce(input, b.ConstantR0(0.0), add_f32, {0}); + + PaddingConfig padding_config = MakeNoPaddingConfig(3); + padding_config.mutable_dimensions(0)->set_edge_padding_low(1); + padding_config.mutable_dimensions(0)->set_edge_padding_high(1); + auto pad = b.Pad(reduce, b.ConstantR0(0.0), padding_config); + + auto ones = MakeUnique>(2, 2, 2, 2); + ones->Fill(1.0); + auto input_literal = LiteralUtil::CreateR4FromArray4D(*ones); + std::unique_ptr input_data = + client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + + Array3D expected({{{0.0, 0.0}, {0.0, 0.0}}, + {{2.0, 2.0}, {2.0, 2.0}}, + {{2.0, 2.0}, {2.0, 2.0}}, + {{0.0, 0.0}, {0.0, 0.0}}}); + ComputeAndCompareR3(&b, expected, {input_data.get()}); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/params_test.cc b/tensorflow/compiler/xla/tests/params_test.cc new file mode 100644 index 0000000000..2f05576cee --- /dev/null +++ b/tensorflow/compiler/xla/tests/params_test.cc @@ -0,0 +1,357 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class ParamsTest : public ClientLibraryTestBase {}; + +XLA_TEST_F(ParamsTest, ConstantR0F32Param) { + ComputationBuilder builder(client_, TestName()); + std::unique_ptr param0_literal = + LiteralUtil::CreateR0(3.14159f); + std::unique_ptr param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + + auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param0"); + + ComputeAndCompareR0(&builder, 3.14159f, {param0_data.get()}, + ErrorSpec(0.0001f)); +} + +XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) { + ComputationBuilder builder(client_, TestName()); + std::unique_ptr param0_literal = LiteralUtil::CreateR1({}); + std::unique_ptr param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + + auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {0}), "param0"); + + ComputeAndCompareR1(&builder, {}, {param0_data.get()}, + ErrorSpec(0.01f)); +} + +XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) { + ComputationBuilder builder(client_, TestName()); + std::unique_ptr param0_literal = + LiteralUtil::CreateR1({3.14f, -100.25f}); + std::unique_ptr param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + + auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2}), "param0"); + + ComputeAndCompareR1(&builder, {3.14f, -100.25f}, {param0_data.get()}, + ErrorSpec(0.01f)); +} + +XLA_TEST_F(ParamsTest, ConstantR1U8Param) { + ComputationBuilder builder(client_, TestName()); + string str("hello world"); + std::unique_ptr param0_literal = LiteralUtil::CreateR1U8(str); + std::unique_ptr param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + + auto p = builder.Parameter( + 0, ShapeUtil::MakeShape(U8, {static_cast(str.size())}), "param0"); + + ComputeAndCompareR1U8(&builder, str, {param0_data.get()}); +} + +XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) { + ComputationBuilder builder(client_, TestName()); + std::unique_ptr param0_literal = + LiteralUtil::CreateR2FromArray2D(Array2D(3, 0)); + std::unique_ptr param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + + auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3, 0}), "param0"); + + ComputeAndCompareR2(&builder, Array2D(3, 0), + {param0_data.get()}, ErrorSpec(0.01f)); +} + +XLA_TEST_F(ParamsTest, ConstantR2F32Param) { + ComputationBuilder builder(client_, TestName()); + std::unique_ptr param0_literal = LiteralUtil::CreateR2( + {{3.14f, -100.25f}, {7e8f, 
7e-9f}, {30.3f, -100.0f}}); + std::unique_ptr param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + + auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3, 2}), "param0"); + + Array2D expected_array( + {{3.14f, -100.25f}, {7e8f, 7e-9f}, {30.3f, -100.0f}}); + ComputeAndCompareR2(&builder, expected_array, {param0_data.get()}, + ErrorSpec(0.01f)); +} + +XLA_TEST_F(ParamsTest, TwoParameters) { + ComputationBuilder builder(client_, TestName()); + + std::unique_ptr literal0 = LiteralUtil::CreateR1({1, 2}); + std::unique_ptr param0_data = + client_->TransferToServer(*literal0).ConsumeValueOrDie(); + auto param0 = builder.Parameter(0, literal0->shape(), "param0"); + + std::unique_ptr literal1 = LiteralUtil::CreateR1({10, 20}); + std::unique_ptr param1_data = + client_->TransferToServer(*literal1).ConsumeValueOrDie(); + auto param1 = builder.Parameter(1, literal1->shape(), "param1"); + + // Use both parameters + // + // {1, 2} + {10, 20} = {11, 22} + auto sum = builder.Add(param0, param1); + sum = builder.Add(param0, param1); + + // Use only the second parameter again, to show that it can be used + // twice and to make the computation asymmetric in the two + // parameters to test that the parameters are not swapped. + // + // {11, 22} * {10, 20} = {110, 440} + auto prod = builder.Mul(sum, param1); + + ComputeAndCompareR1(&builder, {110, 440}, + {param0_data.get(), param1_data.get()}, + ErrorSpec(0.0001f)); +} + +XLA_TEST_F(ParamsTest, MissingParameter) { + // Test that an error is returned when a computation with an incomplete set of + // parameters (parameter numbers not contiguous from 0) is executed. + std::unique_ptr literal = LiteralUtil::CreateR0(3.14159f); + std::unique_ptr data = + client_->TransferToServer(*literal).ConsumeValueOrDie(); + + ComputationBuilder builder(client_, TestName()); + auto p = builder.Parameter(2, ShapeUtil::MakeShape(F32, {}), "param2"); + auto computation = builder.Build().ConsumeValueOrDie(); + + auto execute_status = client_->Execute(computation, {data.get(), data.get()}, + /*output_layout=*/nullptr, + /*execution_profile=*/nullptr); + ASSERT_EQ(execute_status.status().code(), + tensorflow::error::FAILED_PRECONDITION); +} + +XLA_TEST_F(ParamsTest, UnusedParameter) { + ComputationBuilder builder(client_, TestName()); + + std::unique_ptr literal0 = LiteralUtil::CreateR1({1, 2}); + std::unique_ptr param0_data = + client_->TransferToServer(*literal0).ConsumeValueOrDie(); + auto param0 = builder.Parameter(0, literal0->shape(), "param0"); + + std::unique_ptr literal1 = LiteralUtil::CreateR1({10, 20}); + std::unique_ptr param1_data = + client_->TransferToServer(*literal1).ConsumeValueOrDie(); + auto param1 = builder.Parameter(1, literal1->shape(), "param1"); + + ComputeAndCompareR1(&builder, {10, 20}, + {param0_data.get(), param1_data.get()}, + ErrorSpec(0.0001f)); +} + +XLA_TEST_F(ParamsTest, UnusedParametersInUnusedExpression) { + // Build a computation with a couple unused parameters which are used in an + // unused expression. 
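+  // Although only param0 feeds the Neg() that produces the result, every
+  // declared parameter still has to be bound at execution time, which is why
+  // three argument buffers are passed to ComputeAndCompareR1 below.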
+ ComputationBuilder builder(client_, TestName()); + + std::unique_ptr literal0 = LiteralUtil::CreateR1({1, 2}); + std::unique_ptr param0_data = + client_->TransferToServer(*literal0).ConsumeValueOrDie(); + + std::unique_ptr literal1 = + LiteralUtil::CreateR1({10, 20, 30}); + std::unique_ptr param1_data = + client_->TransferToServer(*literal1).ConsumeValueOrDie(); + + auto param0 = builder.Parameter(0, literal0->shape(), "param0"); + auto param1 = builder.Parameter(1, literal1->shape(), "param1"); + auto param2 = builder.Parameter(2, literal1->shape(), "param2"); + + // This add is unused. + builder.Add(param1, param2); + + builder.Neg(param0); + + ComputeAndCompareR1( + &builder, {-1, -2}, + {param0_data.get(), param1_data.get(), param1_data.get()}, + ErrorSpec(0.0001f)); +} + +XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) { + ComputationBuilder builder(client_, TestName()); + constexpr int size = 8 * 128 * 2; + + std::vector init_value = {{0, 1}}; + init_value.resize(size); + ComputationDataHandle sum_handle = builder.ConstantR1(init_value); + std::vector sum = {{0, 1}}; + sum.resize(size); + + std::vector> param_data_owner; + + constexpr int parameter_count = 100; + for (int i = 0; i < parameter_count; ++i) { + const float entry0 = i; + const float entry1 = 2 * i; + sum[0] += entry0; + sum[1] += entry1; + + std::vector sum_value = {{entry0, entry1}}; + sum_value.resize(size); + std::unique_ptr literal = LiteralUtil::CreateR1(sum_value); + param_data_owner.push_back( + client_->TransferToServer(*literal).ConsumeValueOrDie()); + ComputationDataHandle param = + builder.Parameter(i, literal->shape(), "param"); + sum_handle = builder.Add(sum_handle, param); + } + + std::vector param_data; + for (const std::unique_ptr& data : param_data_owner) { + param_data.push_back(data.get()); + } + + ComputeAndCompareR1(&builder, sum, param_data, ErrorSpec(0.0001f)); +} + +XLA_TEST_F(ParamsTest, + DISABLED_ON_CPU_PARALLEL(TupleOfR1ParametersAddedTogether)) { + ComputationBuilder builder(client_, TestName()); + + Shape r1f32_3 = ShapeUtil::MakeShape(F32, {3}); + Shape tuple_shape = ShapeUtil::MakeTupleShape({r1f32_3, r1f32_3}); + auto input = builder.Parameter(0, tuple_shape, "input"); + auto lhs = builder.GetTupleElement(input, 0); + auto rhs = builder.GetTupleElement(input, 1); + builder.Add(lhs, rhs); + + std::unique_ptr data = + client_ + ->TransferToServer(*LiteralUtil::MakeTuple({ + LiteralUtil::CreateR1({1, 2, 3}).get(), + LiteralUtil::CreateR1({4, 5, 6}).get(), + })) + .ConsumeValueOrDie(); + + std::vector arguments = {data.get()}; + const std::vector expected = {1 + 4, 2 + 5, 3 + 6}; + ComputeAndCompareR1(&builder, expected, arguments, ErrorSpec(1e-5)); +} + +// Verifies that passing a 2x2 with {0, 1} layout returns the same value back +// when (transferred to the server and) passed through a parameter. +XLA_TEST_F(ParamsTest, R2_2x2_Layout_01) { + std::unique_ptr literal = LiteralUtil::CreateR2({ + {1, 2}, {3, 4}, + }); + *literal->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); + ComputationBuilder builder(client_, TestName()); + builder.Parameter(0, literal->shape(), "input"); + + std::unique_ptr data = + client_->TransferToServer(*literal).ConsumeValueOrDie(); + ComputeAndCompareLiteral(&builder, *literal, {data.get()}, ErrorSpec(1e-3)); +} + +// As above, but for {1, 0} layout. 
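+// Here minor_to_major {1, 0} is the row-major layout: {{1, 3}, {2, 4}} is
+// transferred in the element order 1, 3, 2, 4, which happens to be the same
+// byte pattern as the column-major case above even though the logical values
+// differ.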
+XLA_TEST_F(ParamsTest, R2_2x2_Layout_10) { + std::unique_ptr literal = LiteralUtil::CreateR2({ + {1, 3}, {2, 4}, + }); + *literal->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 0}); + ComputationBuilder builder(client_, TestName()); + builder.Parameter(0, literal->shape(), "input"); + + std::unique_ptr data = + client_->TransferToServer(*literal).ConsumeValueOrDie(); + ComputeAndCompareLiteral(&builder, *literal, {data.get()}, ErrorSpec(1e-3)); +} + +XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) { + std::unique_ptr literal = LiteralUtil::CreateR2({ + {1, 3}, {2, 4}, + }); + const Shape original = literal->shape(); + { + // Reverse the layout present in original, and make that the layout of the + // literal. + std::vector original_layout( + original.layout().minor_to_major().begin(), + original.layout().minor_to_major().end()); + std::reverse(original_layout.begin(), original_layout.end()); + *literal->mutable_shape()->mutable_layout() = + LayoutUtil::MakeLayout(original_layout); + ASSERT_EQ(2, LiteralUtil::Get(*literal, {0, 1})); + } + // Use the original shape in building the computation. + ComputationBuilder builder(client_, TestName()); + auto input = builder.Parameter(0, original, "input"); + // Use the slice operator to get an off-diagonal element. + builder.Slice(input, {0, 1}, {1, 2}); + + std::unique_ptr data = + client_->TransferToServer(*literal).ConsumeValueOrDie(); + // Check that we got the off-diagonal value that we expected. + Array2D expected(1, 1); + expected(0, 0) = 2; + ComputeAndCompareR2(&builder, expected, {data.get()}, ErrorSpec(1e-3)); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/pred_test.cc b/tensorflow/compiler/xla/tests/pred_test.cc new file mode 100644 index 0000000000..96393c41e8 --- /dev/null +++ b/tensorflow/compiler/xla/tests/pred_test.cc @@ -0,0 +1,115 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Miscellaneous tests with the PRED type that don't fit anywhere else. 
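// The comparison tests in this file rely on PRED values ordering with
// false < true (so Ge and Gt of (true, false) are true while Le and Lt are
// false), which mirrors ordinary C++ bool comparisons. A minimal host-side
// sanity sketch, illustrative only and not part of the test file:
static_assert((true > false) && (true >= false), "Gt/Ge of (true, false)");
static_assert(!(true < false) && !(true <= false), "Lt/Le of (true, false)");
static_assert((true != false) && !(true == false), "Ne/Eq of (true, false)");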
+#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +class PredTest : public ClientLibraryTestBase { + protected: + void TestCompare(bool lhs, bool rhs, bool expected, + ComputationDataHandle (ComputationBuilder::*op)( + const ComputationDataHandle&, + const ComputationDataHandle&, + tensorflow::gtl::ArraySlice)) { + ComputationBuilder builder(client_, TestName()); + ComputationDataHandle lhs_op = builder.ConstantR0(lhs); + ComputationDataHandle rhs_op = builder.ConstantR0(rhs); + ComputationDataHandle result = (builder.*op)(lhs_op, rhs_op, {}); + ComputeAndCompareR0(&builder, expected, {}); + } +}; + +TEST_F(PredTest, ConstantR0PredTrue) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR0(true); + ComputeAndCompareR0(&builder, true, {}); +} + +TEST_F(PredTest, ConstantR0PredFalse) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR0(false); + ComputeAndCompareR0(&builder, false, {}); +} + +TEST_F(PredTest, ConstantR0PredCompareEq) { + TestCompare(true, false, false, &ComputationBuilder::Eq); +} + +TEST_F(PredTest, ConstantR0PredCompareNe) { + TestCompare(true, false, true, &ComputationBuilder::Ne); +} + +TEST_F(PredTest, ConstantR0PredCompareLe) { + TestCompare(true, false, false, &ComputationBuilder::Le); +} + +TEST_F(PredTest, ConstantR0PredCompareLt) { + TestCompare(true, false, false, &ComputationBuilder::Lt); +} + +TEST_F(PredTest, ConstantR0PredCompareGe) { + TestCompare(true, false, true, &ComputationBuilder::Ge); +} + +TEST_F(PredTest, ConstantR0PredCompareGt) { + TestCompare(true, false, true, &ComputationBuilder::Gt); +} + +TEST_F(PredTest, ConstantR1Pred) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({true, false, false, true}); + ComputeAndCompareR1(&builder, {true, false, false, true}, {}); +} + +TEST_F(PredTest, ConstantR2Pred) { + ComputationBuilder builder(client_, TestName()); + auto a = + builder.ConstantR2({{false, true, true}, {true, false, false}}); + const string expected = R"(pred[2,3] { + { 011 }, + { 100 }, +})"; + EXPECT_EQ(expected, ExecuteToString(&builder, {})); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc new file mode 100644 index 0000000000..8d77b3dd61 --- /dev/null +++ b/tensorflow/compiler/xla/tests/prng_test.cc @@ -0,0 +1,238 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class PrngTest : public ClientLibraryTestBase { + protected: + template + void UniformTest(T a, T b, tensorflow::gtl::ArraySlice dims); + void BernoulliTest(float p, tensorflow::gtl::ArraySlice dims); +}; + +template +void PrngTest::UniformTest(T a, T b, tensorflow::gtl::ArraySlice dims) { + ComputationBuilder builder(client_, TestName()); + builder.RngUniform( + builder.ConstantR0(a), builder.ConstantR0(b), + ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), dims)); + + auto actual = ExecuteAndTransferOrDie(&builder, /*arguments=*/{}); + EXPECT_TRUE(ContainersEqual(dims, actual->shape().dimensions())); + LiteralUtil::EachCell(*actual, + [=](tensorflow::gtl::ArraySlice, T value) { + EXPECT_LE(a, value); + EXPECT_GE(b, value); + }); +} + +void PrngTest::BernoulliTest(float p, tensorflow::gtl::ArraySlice dims) { + ComputationBuilder builder(client_, TestName()); + auto shape = ShapeUtil::MakeShape(U32, dims); + builder.RngBernoulli(builder.ConstantR0(p), shape); + + TF_ASSIGN_OR_ASSERT_OK(auto computation, builder.Build()); + constexpr uint64 kTestSeed = 42; + TF_ASSIGN_OR_ASSERT_OK( + auto actual, + client_->ExecuteAndTransfer(computation, /*arguments=*/{}, + /*shape_with_output_layout=*/nullptr, + /*execution_profile=*/nullptr, + /*seed=*/kTestSeed)); + EXPECT_TRUE(ContainersEqual(dims, actual->shape().dimensions())); + int32 sum = 0; + LiteralUtil::EachCell( + *actual, [&sum](tensorflow::gtl::ArraySlice, uint32 value) { + EXPECT_TRUE(value == 0 || value == 1); + sum += value; + }); + int32 total = ShapeUtil::ElementsIn(shape); + float p_tilde = sum / static_cast(total); + + // Test within expected range using normal approximation. The test uses a + // fixed seed and has a fixed output per p and backend. Using the normal + // approximation as this test is invoked for different `p` and the different + // backends could use different random number generators and produce different + // values. Choose 95% confidence level, so that z_{1-\alpha/2} = 1.96. 
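// For example, with p = 0.5 and 100 samples the half-width computed below is
// 1.96 * sqrt(0.25 / 100), about 0.098, so p_tilde is accepted in roughly
// [0.402, 0.598]; with p = 0.1 it is 1.96 * sqrt(0.09 / 100), about 0.059.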
+ float normal_approximation_term = 1.96 * sqrt(p * (1 - p) / total); + EXPECT_GE(p_tilde, p - normal_approximation_term); + EXPECT_LE(p_tilde, p + normal_approximation_term); +} + +// Uniform random number generation tests +XLA_TEST_F(PrngTest, ScalarU01) { UniformTest(0, 1, {}); } +XLA_TEST_F(PrngTest, ZeroValuesU01) { UniformTest(0, 1, {0}); } +XLA_TEST_F(PrngTest, TenValuesU01) { UniformTest(0, 1, {10}); } +XLA_TEST_F(PrngTest, TenValuesU37) { UniformTest(3, 7, {10}); } +XLA_TEST_F(PrngTest, ZeroValuesR2) { UniformTest(0, 1, {0, 20}); } +XLA_TEST_F(PrngTest, LargeU01) { UniformTest(0, 1, {0x100, 0x100}); } +XLA_TEST_F(PrngTest, TwelveValuesU524) { UniformTest(5, 24, {12}); } + +XLA_TEST_F(PrngTest, MapUsingRng) { + // Build a x -> (x + U[0,1)) computation. + auto build_sum_rng = [this](ComputationBuilder& builder) { + auto b = builder.CreateSubBuilder("sum_with_rng"); + auto x = b->Parameter(0, ShapeUtil::MakeShape(F32, {}), "input"); + b->Add(x, + b->RngUniform(b->ConstantR0(0), b->ConstantR0(1), + ShapeUtil::MakeShape(F32, {}))); + return b->BuildAndNoteError(); + }; + + ComputationBuilder builder(client_, TestName()); + std::unique_ptr param0_literal = + LiteralUtil::CreateR1({2.2f, 5.3f, 4.4f, 5.5f}); + TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr param0_data, + client_->TransferToServer(*param0_literal)); + + auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); + auto fn = build_sum_rng(builder); + builder.Map({param0}, fn); + + TF_ASSIGN_OR_ASSERT_OK(auto computation, builder.Build()); + TF_ASSIGN_OR_ASSERT_OK( + auto actual, + client_->ExecuteAndTransfer(computation, + /*arguments=*/{param0_data.get()}, nullptr, + nullptr, /*seed=*/125)); + EXPECT_EQ(actual->f32s_size(), param0_literal->f32s_size()); + for (int i = 0; i < param0_literal->f32s_size(); ++i) { + EXPECT_GE(actual->f32s(i), param0_literal->f32s(i)); + EXPECT_LT(actual->f32s(i), param0_literal->f32s(i) + 1.0f); + } +} + +// This tests demonstrates the global seeding behaviour. +// * If a seed is passed in via Execute (ExecuteAndTransfer) then the output is +// fixed (i.e., there is a single output for a given seed); +// * If no seed is passed in then the output of every call can be different; +XLA_TEST_F(PrngTest, PassInGlobalRngSeed) { + // Build a U[0,1) computation. 
+ auto build_computation = [this]() { + ComputationBuilder builder(client_, TestName()); + builder.RngUniform(builder.ConstantR0(0), + builder.ConstantR0(1), + ShapeUtil::MakeShape(F32, {10})); + return builder.Build(); + }; + + std::unique_ptr result1; + { + TF_ASSIGN_OR_ASSERT_OK(auto computation, build_computation()); + TF_ASSIGN_OR_ASSERT_OK( + result1, + client_->ExecuteAndTransfer(computation, /*arguments=*/{}, + /*shape_with_output_layout=*/nullptr, + /*execution_profile=*/nullptr, + /*seed=*/42)); + } + std::unique_ptr result2; + std::unique_ptr result3; + { + TF_ASSIGN_OR_ASSERT_OK(auto computation, build_computation()); + TF_ASSIGN_OR_ASSERT_OK( + result2, + client_->ExecuteAndTransfer(computation, /*arguments=*/{}, + /*shape_with_output_layout=*/nullptr, + /*execution_profile=*/nullptr, + /*seed=*/42)); + TF_ASSIGN_OR_ASSERT_OK( + result3, + client_->ExecuteAndTransfer(computation, /*arguments=*/{}, + /*shape_with_output_layout=*/nullptr, + /*execution_profile=*/nullptr, + /*seed=*/42)); + } + + std::unique_ptr result4; + std::unique_ptr result5; + std::unique_ptr result6; + { + TF_ASSIGN_OR_ASSERT_OK(auto computation, build_computation()); + TF_ASSIGN_OR_ASSERT_OK( + result4, + client_->ExecuteAndTransfer(computation, /*arguments=*/{}, + /*shape_with_output_layout=*/nullptr, + /*execution_profile=*/nullptr, + /*seed=*/65)); + TF_ASSIGN_OR_ASSERT_OK( + result5, + client_->ExecuteAndTransfer(computation, /*arguments=*/{}, + /*shape_with_output_layout=*/nullptr, + /*execution_profile=*/nullptr)); + TF_ASSIGN_OR_ASSERT_OK( + result6, + client_->ExecuteAndTransfer(computation, /*arguments=*/{}, + /*shape_with_output_layout=*/nullptr, + /*execution_profile=*/nullptr)); + } + + LiteralTestUtil::ExpectEqual(*result1, *result2); + LiteralTestUtil::ExpectEqual(*result1, *result3); + LiteralTestUtil::ExpectNotEqual(*result1, *result4); + LiteralTestUtil::ExpectNotEqual(*result4, *result5); + LiteralTestUtil::ExpectNotEqual(*result5, *result6); +} + +// Bernoulli random number generation tests +XLA_TEST_F(PrngTest, HundredValuesB10p5) { BernoulliTest(0.5, {100}); } +XLA_TEST_F(PrngTest, HundredValuesB10p1) { BernoulliTest(0.1, {100}); } + +XLA_TEST_F(PrngTest, TenValuesN01) { + ComputationBuilder builder(client_, TestName()); + builder.RngNormal(builder.ConstantR0(0), builder.ConstantR0(1), + ShapeUtil::MakeShape(F32, {10})); + + ExecuteAndTransferOrDie(&builder, /*arguments=*/{}); + // TODO(b/25995601): Test that resultant values are reasonable +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc b/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc new file mode 100644 index 0000000000..eb7e63705b --- /dev/null +++ b/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc @@ -0,0 +1,61 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +class QueryInferredShapeTest : public ClientLibraryTestBase {}; + +TEST_F(QueryInferredShapeTest, OnePlusOneShape) { + ComputationBuilder builder(client_, "one_plus_one"); + auto one = builder.ConstantR0(1.0); + auto result = builder.Add(one, one); + StatusOr> shape_status = builder.GetShape(result); + ASSERT_IS_OK(shape_status.status()); + auto shape = shape_status.ConsumeValueOrDie(); + ASSERT_TRUE(ShapeUtil::Equal(*shape, ShapeUtil::MakeShape(F32, {}))); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc new file mode 100644 index 0000000000..f3d8da5c8c --- /dev/null +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -0,0 +1,506 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Tests that multi-dimensional arrays can be reduced among various +// user-provided dimensions. +// +// Note that comments for these tests are white-box in that they talk about the +// default data layout. 
+// +// The test space for reductions is the cartesian product of: +// +// x +// x +// + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/reference_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class ReduceTest : public ClientLibraryTestBase { + protected: + ReduceTest() { + // Implementation note: layed out z >> y >> x by default. + // clang-format off + literal_2d_ = LiteralUtil::CreateR2({ + // x0 x1 x2 + { 1.f, 2.f, 3.f}, // y0 + { 4.f, 5.f, 6.f}, // y1 + }); + literal_3d_ = LiteralUtil::CreateR3Projected({ + // x0 x1 x2 + { 1.f, 2.f, 3.f}, // y0 + { 4.f, 5.f, 6.f}, // y1 + }, 4); + // clang-format on + CHECK(ShapeUtil::Equal( + literal_3d_->shape(), + ShapeUtil::MakeShape(F32, {/*z=*/4, /*y=*/2, /*x=*/3}))) + << literal_3d_->shape().ShortDebugString(); + } + + // Runs an R1 => R0 reduction test with the given number of elements. + void RunR1ToR0Test(int64 element_count) { + ComputationBuilder builder(client_, TestName()); + Computation add_f32 = CreateScalarAddComputation(F32, &builder); + const Shape input_shape = ShapeUtil::MakeShape(F32, {element_count}); + auto input = builder.Parameter(0, input_shape, "input"); + auto zero = builder.ConstantR0(0.0); + builder.Reduce(input, zero, add_f32, /*dimensions_to_reduce=*/{0}); + + std::vector input_data(element_count); + for (int64 i = 0; i < element_count; ++i) { + input_data[i] = rand_r(&seed_) % 3; + if (rand_r(&seed_) % 2 == 0) { + input_data[i] *= -1; + } + } + std::unique_ptr input_literal = + LiteralUtil::CreateR1(AsSlice(input_data)); + std::unique_ptr input_global_data = + client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + + float expected = 0.0; + for (float item : input_data) { + expected += item; + } + ComputeAndCompareR0(&builder, expected, {input_global_data.get()}, + ErrorSpec(0.001)); + } + + // Runs an R2 => R0 reduction test with the given number of (rows, cols). 
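// Note: the R2 helpers below relayout the input literal via
// LayoutUtil::MakeLayout({minor, major}) before transferring it. This changes
// only the physical ordering of the parameter data; the expected values are
// still computed from input_data in logical (row, col) order, so the *_01_*
// test variants exercise the column-major layout path with unchanged results.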
+ void RunR2ToR0Test(int64 rows, int64 cols, int64 minor = 1, int64 major = 0) { + ComputationBuilder builder(client_, TestName()); + Computation add_f32 = CreateScalarAddComputation(F32, &builder); + const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols}); + auto input = builder.Parameter(0, input_shape, "input"); + auto zero = builder.ConstantR0(0.0); + builder.Reduce(input, zero, add_f32, /*dimensions_to_reduce=*/{0, 1}); + + Array2D input_data(rows, cols); + input_data.FillRandom(3.14f, 0.04); + std::unique_ptr input_literal = + LiteralUtil::CreateR2FromArray2D(input_data); + input_literal = LiteralUtil::Relayout( + *input_literal, LayoutUtil::MakeLayout({minor, major})); + std::unique_ptr input_global_data = + client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + + float expected = 0.0; + for (int64 rowno = 0; rowno < rows; ++rowno) { + for (int64 colno = 0; colno < cols; ++colno) { + expected += input_data(rowno, colno); + } + } + ComputeAndCompareR0(&builder, expected, {input_global_data.get()}, + ErrorSpec(0.01, 1e-4)); + } + + // Runs an R2 => R1 reduction test with the given number of (rows, cols). + void RunR2ToR1Test(int64 rows, int64 cols, int64 minor = 1, int64 major = 0) { + ComputationBuilder builder(client_, TestName()); + Computation add_f32 = CreateScalarAddComputation(F32, &builder); + const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols}); + auto input = builder.Parameter(0, input_shape, "input"); + auto zero = builder.ConstantR0(0.0); + builder.Reduce(input, zero, add_f32, /*dimensions_to_reduce=*/{0}); + + Array2D input_data(rows, cols); + input_data.FillRandom(3.14f, 0.04); + std::unique_ptr input_literal = + LiteralUtil::CreateR2FromArray2D(input_data); + input_literal = LiteralUtil::Relayout( + *input_literal, LayoutUtil::MakeLayout({minor, major})); + std::unique_ptr input_global_data = + client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + + std::vector expected; + for (int64 colno = 0; colno < cols; ++colno) { + float column_sum = 0; + for (int64 rowno = 0; rowno < rows; ++rowno) { + column_sum += input_data(rowno, colno); + } + expected.push_back(column_sum); + } + ComputeAndCompareR1(&builder, expected, {input_global_data.get()}, + ErrorSpec(0.01, 1e-4)); + } + + std::unique_ptr literal_2d_; + std::unique_ptr literal_3d_; + uint32 seed_ = 0xdeadbeef; +}; + +XLA_TEST_F(ReduceTest, ReduceR1_0_F32_To_R0) { RunR1ToR0Test(0); } +XLA_TEST_F(ReduceTest, ReduceR1_1_F32_To_R0) { RunR1ToR0Test(1); } +XLA_TEST_F(ReduceTest, ReduceR1_2_F32_To_R0) { RunR1ToR0Test(2); } +XLA_TEST_F(ReduceTest, ReduceR1_16_F32_To_R0) { RunR1ToR0Test(16); } +XLA_TEST_F(ReduceTest, ReduceR1_240_F32_To_R0) { RunR1ToR0Test(240); } +XLA_TEST_F(ReduceTest, ReduceR1_128_F32_To_R0) { RunR1ToR0Test(128); } +XLA_TEST_F(ReduceTest, ReduceR1_129_F32_To_R0) { RunR1ToR0Test(129); } +XLA_TEST_F(ReduceTest, ReduceR1_256_F32_To_R0) { RunR1ToR0Test(256); } +XLA_TEST_F(ReduceTest, ReduceR1_1024_F32_To_R0) { RunR1ToR0Test(1024); } +XLA_TEST_F(ReduceTest, ReduceR1_2048_F32_To_R0) { RunR1ToR0Test(2048); } +XLA_TEST_F(ReduceTest, ReduceR1_16K_F32_To_R0) { RunR1ToR0Test(16 * 1024); } +XLA_TEST_F(ReduceTest, ReduceR1_16KP1_F32_To_R0) { + RunR1ToR0Test(16 * 1024 + 1); +} + +XLA_TEST_F(ReduceTest, ReduceR2_0x0_To_R0) { RunR2ToR0Test(0, 0); } +XLA_TEST_F(ReduceTest, ReduceR2_0x2_To_R0) { RunR2ToR0Test(0, 2); } +XLA_TEST_F(ReduceTest, ReduceR2_1x1_To_R0) { RunR2ToR0Test(1, 1); } +XLA_TEST_F(ReduceTest, ReduceR2_2x0_To_R0) { RunR2ToR0Test(2, 0); } +XLA_TEST_F(ReduceTest, 
ReduceR2_2x2_To_R0) { RunR2ToR0Test(2, 2); } +XLA_TEST_F(ReduceTest, ReduceR2_8x8_To_R0) { RunR2ToR0Test(8, 8); } +XLA_TEST_F(ReduceTest, ReduceR2_9x9_To_R0) { RunR2ToR0Test(9, 9); } +XLA_TEST_F(ReduceTest, ReduceR2_50x111_To_R0) { RunR2ToR0Test(50, 111); } +XLA_TEST_F(ReduceTest, ReduceR2_111x50_To_R0) { RunR2ToR0Test(111, 50); } +XLA_TEST_F(ReduceTest, ReduceR2_111x50_01_To_R0) { + RunR2ToR0Test(111, 50, 0, 1); +} +XLA_TEST_F(ReduceTest, ReduceR2_1024x1024_To_R0) { RunR2ToR0Test(1024, 1024); } +XLA_TEST_F(ReduceTest, ReduceR2_1000x1500_To_R0) { RunR2ToR0Test(1000, 1500); } + +// Disabled due to b/33245142. Failed on 2016-11-30. +// XLA_TEST_F(ReduceTest, ReduceR2_0x0_To_R1) { RunR2ToR1Test(0, 0); } +XLA_TEST_F(ReduceTest, ReduceR2_0x2_To_R1) { RunR2ToR1Test(0, 2); } +XLA_TEST_F(ReduceTest, ReduceR2_1x1_To_R1) { RunR2ToR1Test(1, 1); } +// Disabled due to b/33245142. Failed on 2016-11-30. +// XLA_TEST_F(ReduceTest, ReduceR2_2x0_To_R1) { RunR2ToR1Test(2, 0); } +XLA_TEST_F(ReduceTest, ReduceR2_2x2_To_R1) { RunR2ToR1Test(2, 2); } +XLA_TEST_F(ReduceTest, ReduceR2_8x8_To_R1) { RunR2ToR1Test(8, 8); } +XLA_TEST_F(ReduceTest, ReduceR2_9x9_To_R1) { RunR2ToR1Test(9, 9); } +XLA_TEST_F(ReduceTest, ReduceR2_50x111_To_R1) { RunR2ToR1Test(50, 111); } +XLA_TEST_F(ReduceTest, ReduceR2_111x50_To_R1) { RunR2ToR1Test(111, 50); } +XLA_TEST_F(ReduceTest, ReduceR2_111x50_01_To_R1) { + RunR2ToR1Test(111, 50, 0, 1); +} +XLA_TEST_F(ReduceTest, ReduceR2_1024x1024_To_R1) { RunR2ToR1Test(1024, 1024); } +XLA_TEST_F(ReduceTest, ReduceR2_1000x1500_To_R1) { RunR2ToR1Test(1000, 1500); } + +XLA_TEST_F(ReduceTest, ReduceElementwiseR2_111x50_To_R1) { + const int64 rows = 111, cols = 50; + + ComputationBuilder builder(client_, TestName()); + Computation add_f32 = CreateScalarAddComputation(F32, &builder); + const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols}); + auto input = builder.Parameter(0, input_shape, "input"); + auto zero = builder.ConstantR0(0.0); + auto log_ = builder.Log(input); + builder.Reduce(log_, zero, add_f32, /*dimensions_to_reduce=*/{0}); + + Array2D input_data(rows, cols); + input_data.FillRandom(3.14f, 0.04); + std::unique_ptr input_literal = + LiteralUtil::CreateR2FromArray2D(input_data); + input_literal = + LiteralUtil::Relayout(*input_literal, LayoutUtil::MakeLayout({0, 1})); + std::unique_ptr input_global_data = + client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + + std::vector expected; + for (int64 colno = 0; colno < cols; ++colno) { + float column_sum = 0; + for (int64 rowno = 0; rowno < rows; ++rowno) { + column_sum += log(input_data(rowno, colno)); + } + expected.push_back(column_sum); + } + ComputeAndCompareR1(&builder, expected, {input_global_data.get()}, + ErrorSpec(0.01, 1e-4)); +} + +struct BoundsLayout { + std::vector bounds; + std::vector layout; + std::vector reduce_dims; +}; + +void PrintTo(const BoundsLayout& spec, std::ostream* os) { + *os << tensorflow::strings::Printf( + "R%luToR%lu%s_%s_Reduce%s", spec.bounds.size(), + spec.bounds.size() - spec.reduce_dims.size(), + tensorflow::str_util::Join(spec.bounds, "x").c_str(), + tensorflow::str_util::Join(spec.layout, "").c_str(), + tensorflow::str_util::Join(spec.reduce_dims, "").c_str()); +} + +// Add-reduces a broadcasted scalar matrix among dimension 1 and 0. 
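// Reducing a broadcast of a scalar is easy to check on the host: every
// element equals the scalar, so an add-reduce over all dimensions gives
// scalar * element_count, while a max-reduce gives the scalar back. The value
// expected by the next test, for instance:
static_assert(42.0f * 500 * 500 == 10500000.0f,
              "add-reduce of 42 broadcast over 500x500 elements");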
+XLA_TEST_F(ReduceTest, AddReduce2DScalarToR0) {
+  ComputationBuilder builder(client_, TestName());
+  auto add = CreateScalarAddComputation(F32, &builder);
+  auto scalar = builder.ConstantR0<float>(42.0);
+  auto broadcasted = builder.Broadcast(scalar, {500, 500});
+  builder.Reduce(broadcasted, builder.ConstantR0<float>(0.0f), add, {0, 1});
+
+  float expected = 42.0f * static_cast<float>(500 * 500);
+  ComputeAndCompareR0<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+// Max-reduces a broadcasted scalar matrix among dimensions 1 and 0.
+XLA_TEST_F(ReduceTest, MaxReduce2DScalarToR0) {
+  ComputationBuilder builder(client_, TestName());
+  auto max = CreateScalarMaxComputation(F32, &builder);
+  auto scalar = builder.ConstantR0<float>(42.0);
+  auto broadcasted = builder.Broadcast(scalar, {500, 500});
+  builder.Reduce(broadcasted, builder.ConstantR0<float>(0.0f), max, {0, 1});
+
+  float expected = 42.0f;
+  ComputeAndCompareR0<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+// Max-reduces a matrix among dimensions 1 and 0.
+XLA_TEST_F(ReduceTest, MaxReduce2DToR0) {
+  ComputationBuilder builder(client_, TestName());
+  auto max = CreateScalarMaxComputation(F32, &builder);
+  Array2D<float> input(300, 250);
+  input.FillRandom(214.0f);
+  auto input_literal = LiteralUtil::CreateR2FromArray2D<float>(input);
+  builder.Reduce(builder.ConstantLiteral(*input_literal),
+                 builder.ConstantR0<float>(FLT_MIN), max, {0, 1});
+  auto input_max = FLT_MIN;
+  input.Each(
+      [&](int64, int64, float* v) { input_max = std::max(input_max, *v); });
+  ComputeAndCompareR0<float>(&builder, input_max, {}, ErrorSpec(0.0001));
+}
+
+// Min-reduces a matrix among dimensions 1 and 0.
+XLA_TEST_F(ReduceTest, MinReduce2DToR0) {
+  ComputationBuilder builder(client_, TestName());
+  auto min = CreateScalarMinComputation(F32, &builder);
+  Array2D<float> input(150, 130);
+  input.FillRandom(214.0f);
+  auto input_literal = LiteralUtil::CreateR2FromArray2D<float>(input);
+  builder.Reduce(builder.ConstantLiteral(*input_literal),
+                 builder.ConstantR0<float>(FLT_MAX), min, {0, 1});
+
+  auto input_min = FLT_MAX;
+  input.Each(
+      [&](int64, int64, float* v) { input_min = std::min(input_min, *v); });
+  ComputeAndCompareR0<float>(&builder, input_min, {}, ErrorSpec(0.0001));
+}
+
+// Reduces a matrix among dimension 1.
+XLA_TEST_F(ReduceTest, Reduce2DAmong1) {
+  ComputationBuilder builder(client_, TestName());
+  auto m = builder.ConstantLiteral(*literal_2d_);
+  auto add = CreateScalarAddComputation(F32, &builder);
+  builder.Reduce(m, builder.ConstantR0<float>(0.0f), add, {1});
+
+  std::vector<float> expected = {6.f, 15.f};
+  ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ReduceTest, Reduce2DAmong0and1) {
+  // Reduce a matrix among dimensions 0 and 1 (sum it up to a scalar).
+  ComputationBuilder builder(client_, TestName());
+  auto m = builder.ConstantLiteral(*literal_2d_);
+  auto add = CreateScalarAddComputation(F32, &builder);
+  builder.Reduce(m, builder.ConstantR0<float>(0.0f), add, {0, 1});
+
+  ComputeAndCompareR0<float>(&builder, 21.0f, {}, ErrorSpec(0.0001, 1e-4));
+}
+
+// Tests 2D matrix ReduceToRow operation.
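// Host-side reference for the 2x3 literal_2d_ used by the nearby tests
// (illustrative sketch with a made-up helper name, not from the XLA sources):
// reducing over dimension 1 produces row sums, reducing over dimension 0
// produces column sums, and reducing over both produces the grand total.
#include <array>
#include <vector>

std::vector<float> ReduceDim2D(const std::array<std::array<float, 3>, 2>& a,
                               int dim) {
  // dim == 0 collapses rows (column sums); dim == 1 collapses columns
  // (row sums).
  std::vector<float> out(dim == 0 ? 3 : 2, 0.0f);
  for (int y = 0; y < 2; ++y) {
    for (int x = 0; x < 3; ++x) {
      out[dim == 0 ? x : y] += a[y][x];
    }
  }
  return out;
}

// With a = {{1, 2, 3}, {4, 5, 6}}: ReduceDim2D(a, 1) == {6, 15},
// ReduceDim2D(a, 0) == {5, 7, 9}, and either result sums to 21, matching
// Reduce2DAmong1, Reduce2DAmongY, and Reduce2DAmong0and1.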
+XLA_TEST_F(ReduceTest, Reduce2DAmongY) { + ComputationBuilder builder(client_, "reduce_among_y"); + auto m = builder.ConstantLiteral(*literal_2d_); + auto add = CreateScalarAddComputation(F32, &builder); + builder.Reduce(m, builder.ConstantR0(0.0f), add, {0}); + + std::vector expected = {5.f, 7.f, 9.f}; + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(ReduceTest, ReduceR3AmongDims_1_2) { + ComputationBuilder builder(client_, TestName()); + auto m = builder.ConstantLiteral(*literal_3d_); + auto add = CreateScalarAddComputation(F32, &builder); + builder.Reduce(m, builder.ConstantR0(0.0f), add, {1, 2}); + + std::vector expected = {21.f, 21.f, 21.f, 21.f}; + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(ReduceTest, ReduceR3AmongDims_0_1) { + ComputationBuilder builder(client_, TestName()); + auto m = builder.ConstantLiteral(*literal_3d_); + auto add = CreateScalarAddComputation(F32, &builder); + builder.Reduce(m, builder.ConstantR0(0.0f), add, {0, 1}); + + std::vector expected = {20.f, 28.f, 36.f}; + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(ReduceTest, ReduceR3ToR0) { + ComputationBuilder builder(client_, TestName()); + auto m = builder.ConstantLiteral(*literal_3d_); + auto add = CreateScalarAddComputation(F32, &builder); + builder.Reduce(m, builder.ConstantR0(0.0f), add, {0, 1, 2}); + + float expected = 21.0f * 4.0; + ComputeAndCompareR0(&builder, expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(ReduceTest, ReduceR3AmongDim0) { + ComputationBuilder builder(client_, TestName()); + auto m = builder.ConstantLiteral(*literal_3d_); + auto add = CreateScalarAddComputation(F32, &builder); + builder.Reduce(m, builder.ConstantR0(0.0f), add, {0}); + + // clang-format off + Array2D expected({ + {4.f, 8.f, 12.f}, + {16.f, 20.f, 24.f}, + }); + // clang-format on + ComputeAndCompareR2(&builder, expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(ReduceTest, ReduceR3AmongDim1) { + ComputationBuilder builder(client_, TestName()); + auto m = builder.ConstantLiteral(*literal_3d_); + auto add = CreateScalarAddComputation(F32, &builder); + builder.Reduce(m, builder.ConstantR0(0.0f), add, {1}); + + // clang-format off + Array2D expected({ + {5.f, 7.f, 9.f}, + {5.f, 7.f, 9.f}, + {5.f, 7.f, 9.f}, + {5.f, 7.f, 9.f}, + }); + // clang-format on + ComputeAndCompareR2(&builder, expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(ReduceTest, ReduceR3AmongDim2) { + ComputationBuilder builder(client_, TestName()); + auto m = builder.ConstantLiteral(*literal_3d_); + auto add = CreateScalarAddComputation(F32, &builder); + builder.Reduce(m, builder.ConstantR0(0.0f), add, {2}); + + // clang-format off + Array2D expected({ + {6.f, 15.f}, + {6.f, 15.f}, + {6.f, 15.f}, + {6.f, 15.f}, + }); + // clang-format on + ComputeAndCompareR2(&builder, expected, {}, ErrorSpec(0.0001)); +} + +class ReduceR3ToR2Test : public ReduceTest, + public ::testing::WithParamInterface {}; + +XLA_TEST_P(ReduceR3ToR2Test, ReduceR3ToR2) { + ComputationBuilder builder(client_, TestName()); + const auto& bounds = GetParam().bounds; + Array3D input_array(bounds[0], bounds[1], bounds[2]); + input_array.FillRandom(3.14f, 0.05); + + auto input_literal = LiteralUtil::CreateR3FromArray3D(input_array); + input_literal = LiteralUtil::Relayout( + *input_literal, LayoutUtil::MakeLayout(GetParam().layout)); + std::unique_ptr input_data = + client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + + auto input_activations = + builder.Parameter(0, 
input_literal->shape(), "input"); + Computation add = CreateScalarAddComputation(F32, &builder); + auto sum = builder.Reduce(input_activations, builder.ConstantR0(0.0f), + add, GetParam().reduce_dims); + + auto expected = + ReferenceUtil::Reduce3DTo2D(input_array, 0.0f, GetParam().reduce_dims, + [](float a, float b) { return a + b; }); + + ComputeAndCompareR2(&builder, *expected, {input_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +INSTANTIATE_TEST_CASE_P( + ReduceR3ToR2Test_Instantiation, ReduceR3ToR2Test, + // Specifies (shape, layout, reduction dimensions). + ::testing::Values(BoundsLayout{{4, 8, 128}, {2, 1, 0}, {0}}, + BoundsLayout{{4, 8, 128}, {2, 1, 0}, {1}}, + BoundsLayout{{4, 8, 128}, {2, 1, 0}, {2}}, + // These should be simplified into a reshape. + BoundsLayout{{1, 21, 43}, {2, 1, 0}, {0}}, + BoundsLayout{{1, 1, 1}, {2, 1, 0}, {0}}, + BoundsLayout{{1, 1, 1}, {2, 1, 0}, {1}}, + BoundsLayout{{1, 1, 1}, {2, 1, 0}, {2}}, + BoundsLayout{{8, 16, 24}, {0, 1, 2}, {0}}, + BoundsLayout{{8, 16, 24}, {0, 1, 2}, {1}}, + BoundsLayout{{8, 16, 24}, {0, 1, 2}, {2}}, + BoundsLayout{{5, 10, 250}, {2, 1, 0}, {0}}, + BoundsLayout{{5, 10, 250}, {2, 1, 0}, {1}}, + BoundsLayout{{5, 10, 250}, {2, 1, 0}, {2}}, + BoundsLayout{{8, 16, 256}, {2, 1, 0}, {0}}, + BoundsLayout{{8, 16, 256}, {2, 1, 0}, {1}}, + BoundsLayout{{8, 16, 256}, {2, 1, 0}, {2}}, + BoundsLayout{{2, 300, 784}, {2, 1, 0}, {2}}, + BoundsLayout{{2, 300, 784}, {2, 1, 0}, {1}}, + BoundsLayout{{2, 300, 784}, {2, 1, 0}, {0}})); + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc new file mode 100644 index 0000000000..f48c14dfc6 --- /dev/null +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -0,0 +1,445 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Tests the reduce-window XLA operation. 
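// Illustrative host-side reference for the 1-D reduce-window cases in this
// file (hypothetical helper, not from the XLA sources). With Padding::kValid,
// each output element reduces a window of `window` inputs and consecutive
// windows start `stride` elements apart; kSame instead pads the input so the
// output keeps ceil(n / stride) elements.
#include <vector>

std::vector<float> ReduceWindowAddValid(const std::vector<float>& in,
                                        int window, int stride) {
  std::vector<float> out;
  for (int start = 0; start + window <= static_cast<int>(in.size());
       start += stride) {
    float sum = 0.0f;
    for (int i = 0; i < window; ++i) sum += in[start + i];
    out.push_back(sum);
  }
  return out;
}

// For example ReduceWindowAddValid({20, 100, 3}, 3, 1) == {123} and
// ReduceWindowAddValid({1, 2, ..., 16}, 4, 4) == {10, 26, 42, 58}, matching
// the Add3In3 and Add4In16Stride4 tests below.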
+ +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/array3d.h" +#include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/padding.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/reference_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class ReduceWindowTest : public ClientLibraryTestBase { + public: + ReduceWindowTest() : builder_(client_, TestName()) {} + + void ReduceWindowAdd(ComputationDataHandle input, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + Padding padding) { + builder_.ReduceWindow(input, builder_.ConstantR0(0.0f), + CreateScalarAddComputation(F32, &builder_), + window_dimensions, window_strides, padding); + } + + void ReduceWindowMax(ComputationDataHandle input, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + Padding padding) { + builder_.ReduceWindow( + input, builder_.ConstantLiteral(LiteralUtil::MinValue(F32)), + CreateScalarMax(), window_dimensions, window_strides, padding); + } + + void ReduceWindowMin(ComputationDataHandle input, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + Padding padding) { + builder_.ReduceWindow(input, + builder_.ConstantLiteral(LiteralUtil::MaxValue(F32)), + CreateScalarMinComputation(F32, &builder_), + window_dimensions, window_strides, padding); + } + + ComputationBuilder builder_; +}; + +XLA_TEST_F(ReduceWindowTest, ZeroElementSmall) { + Array4D input_array(1, 0, 2, 1); + + const auto input = builder_.ConstantR4FromArray4D(input_array); + Padding padding = Padding::kSame; + ReduceWindowAdd(input, {1, 1, 2, 1}, {1, 1, 1, 1}, padding); + + auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1}, + {1, 1, 1, 1}, padding); + + ComputeAndCompareR4(&builder_, *res, {}, ErrorSpec(1e-3, 1e-3)); +} + +TEST_F(ReduceWindowTest, NonSquareSmall) { + Array4D input_array(1, 2, 2, 1); + input_array.FillRandom(2.f); + + const auto input = builder_.ConstantR4FromArray4D(input_array); + Padding padding = Padding::kSame; + ReduceWindowAdd(input, {1, 1, 2, 1}, {1, 1, 1, 1}, padding); + + auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1}, + {1, 1, 1, 1}, padding); + + ComputeAndCompareR4(&builder_, *res, {}, ErrorSpec(1e-3, 1e-3)); +} + +TEST_F(ReduceWindowTest, MiddleDimsSmall) { + Array4D input_array(1, 3, 3, 1); + input_array.FillRandom(2.f); + + const auto input = builder_.ConstantR4FromArray4D(input_array); + Padding padding = Padding::kSame; + ReduceWindowAdd(input, {1, 1, 1, 1}, {1, 2, 2, 1}, padding); + + auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 1, 1}, + {1, 2, 2, 1}, padding); + + ComputeAndCompareR4(&builder_, *res, {}, ErrorSpec(1e-3, 1e-3)); +} + +TEST_F(ReduceWindowTest, Along2ndMinorDim) { + Array4D 
input_array(3, 6, 7, 32); + input_array.FillRandom(2.f); + + // The parameters of this reduction mimic feature norm (e.g. LRN). + int lrn_diameter = 7; // diameter = 2*radius + 1 --> must be odd + const auto input = builder_.ConstantR4FromArray4D(input_array); + Padding padding = Padding::kSame; + ReduceWindowAdd(input, {1, 1, lrn_diameter, 1}, {1, 1, 1, 1}, padding); + + auto res = ReferenceUtil::ReduceWindow4DAdd( + input_array, 0.0f, {1, 1, lrn_diameter, 1}, {1, 1, 1, 1}, padding); + + ComputeAndCompareR4(&builder_, *res, {}, ErrorSpec(1e-3, 1e-3)); +} + +TEST_F(ReduceWindowTest, AmongMajor2DimsMediumSize) { + Array4D input_array(9, 12, 4, 89); + input_array.FillRandom(2.0f); + + int win_len = 3; + int win_stride = 2; + + const auto input_data_handle = + builder_.ConstantR4FromArray4D(input_array); + + Padding padding = Padding::kSame; + // Reduce only along the x and y dimensions, according to the win_len. + ReduceWindowAdd(input_data_handle, {win_len, win_len, 1, 1}, + {win_stride, win_stride, 1, 1}, padding); + + auto result = ReferenceUtil::ReduceWindow4DAdd( + input_array, 0.0f, {win_len, win_len, 1, 1}, + {win_stride, win_stride, 1, 1}, padding); + + ComputeAndCompareR4(&builder_, *result, {}, ErrorSpec(1e-3, 1e-3)); +} + +// TODO(b/32173947): Test support for arbitrary-sized padding. +TEST_F(ReduceWindowTest, DISABLED_AmongMajor2DimsMediumSizeLargePadding) { + Array4D input_array(9, 12, 4, 89); // simulate Dim0IsMinor layout + input_array.FillRandom(2.0f); + + int64 rank = 4; + int win_len = 3; + int win_stride = 2; + + const auto input_data_handle = + builder_.ConstantR4FromArray4D(input_array); + + Padding padding = Padding::kSame; + // Reduce only along the x and y dimensions, according to the win_len. + // Create padding vector with large padding values in the reduction dims. + std::vector> low_high_padding; + low_high_padding.resize(rank, {4, 4}); + + builder_.ReduceWindowWithGeneralPadding( + input_data_handle, builder_.ConstantR0(0.0f), + CreateScalarAddComputation(F32, &builder_), {win_len, win_len, 1, 1}, + {win_stride, win_stride, 1, 1}, low_high_padding); + + auto result = ReferenceUtil::ReduceWindow4DAdd( + input_array, 0.0f, {win_len, win_len, 1, 1}, + {win_stride, win_stride, 1, 1}, padding); + + ComputeAndCompareR4(&builder_, *result, {}, ErrorSpec(1e-3, 1e-3)); +} +// TODO(b/31809540): Implement minor dim reduction to reduce num of reshapes. +TEST_F(ReduceWindowTest, ReduceR4AmongXYMinorSmall) { + Array4D input_array(2, 2, 4, 16); + + Array2D yx({{0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, + 11.f, 12.f, 13.f, 14.f, 15.f}, + {16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, + 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f}, + {32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f, + 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f}, + {48.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, + 57.f, 58.f, 59.f, 60.f, 61.f, 62.f, 63.f}}); + input_array.FillWithYX(yx); + + int win_len = 2; + int win_stride = 2; + const auto input = builder_.ConstantR4FromArray4D(input_array); + Padding padding = Padding::kValid; + ReduceWindowAdd(input, {1, 1, win_len, win_len}, + {1, 1, win_stride, win_stride}, padding); + + auto res = ReferenceUtil::ReduceWindow4DAdd( + input_array, 0.0f, {1, 1, win_len, win_len}, + {1, 1, win_stride, win_stride}, padding); + ComputeAndCompareR4(&builder_, *res, {}, ErrorSpec(1e-3, 1e-3)); +} + +// TODO(b/31809540): Implement minor dim reduction to reduce num of reshapes. 
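// In the overlapped variant below, each valid window covers a full 4x4 block
// of the y-x plane (entries 16*y + x with x in [2j, 2j+3]), so output column j
// sums to 408 + 32*j; sliding the window by two columns adds 32, which is
// where the expected row {408, 440, 472, 504, 536, 568, 600} comes from.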
+TEST_F(ReduceWindowTest, ReduceR4AmongXYMinorSmallOverlapped) { + constexpr int64 p = 2; + constexpr int64 z = 2; + constexpr int64 y = 4; + constexpr int64 x = 16; + Array4D input_array(p, z, y, x); + + Array2D yx({{0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, + 11.f, 12.f, 13.f, 14.f, 15.f}, + {16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, + 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f}, + {32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f, + 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f}, + {48.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, + 57.f, 58.f, 59.f, 60.f, 61.f, 62.f, 63.f}}); + input_array.FillWithYX(yx); + + int win_len = 4; + int win_stride = 2; + const auto input = builder_.ConstantR4FromArray4D(input_array); + ReduceWindowAdd(input, {1, 1, win_len, win_len}, + {1, 1, win_stride, win_stride}, Padding::kValid); + + // Expected result + Array2D yx_result({{408.f, 440.f, 472.f, 504.f, 536.f, 568.f, 600.f}}); + Array4D expected(p, z, 1, 7); + expected.FillWithYX(yx_result); + ComputeAndCompareR4(&builder_, expected, {}, ErrorSpec(1e-3, 1e-3)); +} + +TEST_F(ReduceWindowTest, MaxTrivial) { + const auto input = builder_.ConstantR1({42}); + ReduceWindowMax(input, {1}, {1}, Padding::kValid); + ComputeAndCompareR1(&builder_, {42}, {}, ErrorSpec(0.0001)); +} + +TEST_F(ReduceWindowTest, Add3In3) { + const auto input = builder_.ConstantR1({20, 100, 3}); + ReduceWindowAdd(input, {3}, {1}, Padding::kValid); + ComputeAndCompareR1(&builder_, {123}, {}, ErrorSpec(0.0001)); +} + +TEST_F(ReduceWindowTest, Add4In16Stride4) { + const auto input = builder_.ConstantR1( + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + ReduceWindowAdd(input, {4}, {4}, Padding::kValid); + ComputeAndCompareR1(&builder_, {10, 26, 42, 58}, {}, + ErrorSpec(0.0001)); +} + +TEST_F(ReduceWindowTest, DISABLED_ON_CPU(DISABLED_ON_GPU(Min3In5Stride2))) { + const auto input = builder_.ConstantR1({10000, 1000, 100, 10, 1}); + ReduceWindowMin(input, {3}, {2}, Padding::kValid); + ComputeAndCompareR1(&builder_, {100, 1}, {}, ErrorSpec(0.0001)); +} + +TEST_F(ReduceWindowTest, Max3In3) { + const auto input = builder_.ConstantR1({20, 100, 3}); + ReduceWindowMax(input, {3}, {1}, Padding::kValid); + ComputeAndCompareR1(&builder_, {100}, {}, ErrorSpec(0.0001)); +} + +TEST_F(ReduceWindowTest, Add2In3) { + const auto input = builder_.ConstantR1({100, 10, 1}); + ReduceWindowAdd(input, {2}, {1}, Padding::kValid); + ComputeAndCompareR1(&builder_, {110, 11}, {}, ErrorSpec(0.0001)); +} + +TEST_F(ReduceWindowTest, Add3In5Stride2) { + const auto input = builder_.ConstantR1({10000, 1000, 100, 10, 1}); + ReduceWindowAdd(input, {3}, {2}, Padding::kValid); + ComputeAndCompareR1(&builder_, {11100, 111}, {}, ErrorSpec(0.0001)); +} + +TEST_F(ReduceWindowTest, Max4In16Stride4) { + const auto input = builder_.ConstantR1( + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + ReduceWindowMax(input, {4}, {4}, Padding::kValid); + ComputeAndCompareR1(&builder_, {4, 8, 12, 16}, {}, ErrorSpec(0.0001)); +} + +TEST_F(ReduceWindowTest, Max4In16Stride3) { + const auto input = builder_.ConstantR1( + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + ReduceWindowMax(input, {4}, {3}, Padding::kValid); + ComputeAndCompareR1(&builder_, {4, 7, 10, 13, 16}, {}, + ErrorSpec(0.0001)); +} + +TEST_F(ReduceWindowTest, Max4In16Stride8) { + const auto input = builder_.ConstantR1( + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + ReduceWindowMax(input, {4}, {8}, Padding::kValid); + ComputeAndCompareR1(&builder_, {4, 12}, 
{}, ErrorSpec(0.0001)); +} + +TEST_F(ReduceWindowTest, Max3In5Stride2) { + const auto input = builder_.ConstantR1({10000, 1000, 100, 10, 1}); + ReduceWindowMax(input, {3}, {2}, Padding::kValid); + ComputeAndCompareR1(&builder_, {10000, 100}, {}, ErrorSpec(0.0001)); +} + +TEST_F(ReduceWindowTest, Max3In5Stride1) { + const auto input = builder_.ConstantR1({10000, 1000, 100, 10, 101}); + ReduceWindowMax(input, {3}, {1}, Padding::kValid); + ComputeAndCompareR1(&builder_, {10000, 1000, 101}, {}, + ErrorSpec(0.0001)); +} + +TEST_F(ReduceWindowTest, Add3In4Stride2) { + const auto input = builder_.ConstantR1({1000, 100, 10, 1}); + ReduceWindowAdd(input, {3}, {2}, Padding::kValid); + ComputeAndCompareR1(&builder_, {1110}, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(ReduceWindowTest, Add2In3SamePad) { + const auto input = builder_.ConstantR1({100, 10, 1}); + ReduceWindowAdd(input, {2}, {1}, Padding::kSame); + ComputeAndCompareR1(&builder_, {110, 11, 1}, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(ReduceWindowTest, Add3In3SamePad) { + const auto input = builder_.ConstantR1({100, 10, 1}); + ReduceWindowAdd(input, {3}, {1}, Padding::kSame); + ComputeAndCompareR1(&builder_, {110, 111, 11}, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(ReduceWindowTest, Add3In3Stride3SamePad) { + const auto input = builder_.ConstantR1({100, 10, 1}); + ReduceWindowAdd(input, {3}, {2}, Padding::kSame); + ComputeAndCompareR1(&builder_, {110, 11}, {}, ErrorSpec(0.0001)); +} + +TEST_F(ReduceWindowTest, Add2x2In2x2Overlapped) { + Array2D input_array({{1.2f, -2.5f, 0.9f, 1.0f}, + {3.7f, 0.2f, -1.0f, -0.2f}, + {-0.4f, 2.7f, 1.1f, 2.2f}, + {0.6f, 1.7f, 1.4f, -0.2f}}); + auto input = builder_.ConstantR2FromArray2D(input_array); + ReduceWindowAdd(input, {2, 2}, {1, 1}, Padding::kValid); + Array2D expected( + {{2.6f, -2.4f, 0.7f}, {6.2f, 3.0f, 2.1f}, {4.6f, 6.9f, 4.5f}}); + ComputeAndCompareR2(&builder_, expected, {}, ErrorSpec(0.0001)); +} + +TEST_F(ReduceWindowTest, Add2x2In2x2Disjoint) { + Array2D input_array({{1.2f, -2.5f, 0.9f, 1.0f}, + {3.7f, 0.2f, -1.0f, -0.2f}, + {-0.4f, 2.7f, 1.1f, 2.2f}, + {0.6f, 1.7f, 1.4f, -0.2f}}); + auto input = builder_.ConstantR2FromArray2D(input_array); + ReduceWindowAdd(input, {2, 2}, {2, 2}, Padding::kValid); + Array2D expected({ + {2.6f, 0.7f}, {4.6f, 4.5f}, + }); + ComputeAndCompareR2(&builder_, expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(ReduceWindowTest, Add1x1x2In2x1x2) { + Array3D input_array(2, 1, 2); + input_array(0, 0, 0) = 1000; + input_array(0, 0, 1) = 100; + input_array(1, 0, 0) = 10; + input_array(1, 0, 1) = 1; + auto input = builder_.ConstantR3FromArray3D(input_array); + + ReduceWindowAdd(input, {1, 1, 2}, {1, 1, 1}, Padding::kValid); + + Array3D expected(2, 1, 1); + expected(0, 0, 0) = 1100; + expected(1, 0, 0) = 11; + ComputeAndCompareR3(&builder_, expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(ReduceWindowTest, Add1x1x2In2x1x3Stride1x1x2) { + Array3D input_array(2, 1, 3); + input_array(0, 0, 0) = 100; + input_array(0, 0, 1) = 10; + input_array(0, 0, 2) = 1; + input_array(1, 0, 0) = 500; + input_array(1, 0, 1) = 50; + input_array(1, 0, 2) = 5; + auto input = builder_.ConstantR3FromArray3D(input_array); + + ReduceWindowAdd(input, {1, 1, 2}, {1, 1, 2}, Padding::kValid); + + Array3D expected(2, 1, 1); + expected(0, 0, 0) = 110; + expected(1, 0, 0) = 550; + ComputeAndCompareR3(&builder_, expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(ReduceWindowTest, Add1x1x2In2x1x3SamePad) { + Array3D input_array(2, 1, 3); + input_array(0, 0, 0) = 100; + input_array(0, 0, 1) = 10; + 
input_array(0, 0, 2) = 1; + input_array(1, 0, 0) = 500; + input_array(1, 0, 1) = 50; + input_array(1, 0, 2) = 5; + auto input = builder_.ConstantR3FromArray3D(input_array); + + ReduceWindowAdd(input, {1, 1, 2}, {1, 1, 1}, Padding::kSame); + + Array3D expected(2, 1, 3); + expected(0, 0, 0) = 110; + expected(0, 0, 1) = 11; + expected(0, 0, 2) = 1; + expected(1, 0, 0) = 550; + expected(1, 0, 1) = 55; + expected(1, 0, 2) = 5; + ComputeAndCompareR3(&builder_, expected, {}, ErrorSpec(0.0001)); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/replay_test.cc b/tensorflow/compiler/xla/tests/replay_test.cc new file mode 100644 index 0000000000..802087b508 --- /dev/null +++ b/tensorflow/compiler/xla/tests/replay_test.cc @@ -0,0 +1,168 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/protobuf_util.h" +#include "tensorflow/compiler/xla/service/session.pb.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class ReplayTest : public ClientLibraryTestBase {}; + +TEST_F(ReplayTest, TwoPlusTwoReplay) { + // Make 2+2 computation. + ComputationBuilder builder(client_, TestName()); + auto two = builder.ConstantR0(2); + builder.Add(two, two); + Computation computation = builder.Build().ConsumeValueOrDie(); + + // Serialize it out. + std::unique_ptr module = + computation.Snapshot().ConsumeValueOrDie(); + + // Replay it. + Computation replayed = client_->LoadSnapshot(*module).ConsumeValueOrDie(); + + // Check signature is the same. 
+ std::unique_ptr original_shape = + client_->GetComputationShape(computation).ConsumeValueOrDie(); + std::unique_ptr replayed_shape = + client_->GetComputationShape(replayed).ConsumeValueOrDie(); + ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape)); + + // Run it. + std::unique_ptr literal = + client_->ExecuteAndTransfer(replayed, /*arguments=*/{}) + .ConsumeValueOrDie(); + + // Expect 4. + LiteralTestUtil::ExpectR0Equal(4, *literal); +} + +XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) { + // Make computation. + ComputationBuilder builder(client_, TestName()); + auto x = builder.Parameter(0, ShapeUtil::MakeShape(S32, {}), "x"); + auto y = builder.Parameter(1, ShapeUtil::MakeShape(S32, {}), "y"); + builder.Add(x, y); + Computation computation = builder.Build().ConsumeValueOrDie(); + + // Serialize it out. + std::unique_ptr module = + computation.Snapshot().ConsumeValueOrDie(); + + // Replay it. + Computation replayed = client_->LoadSnapshot(*module).ConsumeValueOrDie(); + + // Check signature is the same. + std::unique_ptr original_shape = + client_->GetComputationShape(computation).ConsumeValueOrDie(); + std::unique_ptr replayed_shape = + client_->GetComputationShape(replayed).ConsumeValueOrDie(); + ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape)); + + // Run it. + std::unique_ptr x_data = + client_->TransferToServer(*LiteralUtil::CreateR0(2)) + .ConsumeValueOrDie(); + std::unique_ptr y_data = + client_->TransferToServer(*LiteralUtil::CreateR0(3)) + .ConsumeValueOrDie(); + std::unique_ptr literal = + client_ + ->ExecuteAndTransfer(replayed, + /*arguments=*/{x_data.get(), y_data.get()}) + .ConsumeValueOrDie(); + + // Expect 5. + LiteralTestUtil::ExpectR0Equal(5, *literal); +} + +TEST_F(ReplayTest, MapPlusTwoOverR1) { + // As above, but with map(+2) over some constant array. + ComputationBuilder plus_two_builder(client_, "plus two"); + auto input = + plus_two_builder.Parameter(0, ShapeUtil::MakeShape(S32, {}), "input"); + plus_two_builder.Add(input, plus_two_builder.ConstantR0(2)); + Computation plus_two = plus_two_builder.Build().ConsumeValueOrDie(); + + ComputationBuilder mapper_builder(client_, TestName()); + auto original = mapper_builder.ConstantR1({1, 2, 3}); + mapper_builder.Map({original}, plus_two); + + Computation computation = mapper_builder.Build().ConsumeValueOrDie(); + + // Serialize it out. + std::unique_ptr module = + computation.Snapshot().ConsumeValueOrDie(); + + // Replay it. + Computation replayed = client_->LoadSnapshot(*module).ConsumeValueOrDie(); + + // Check signature is the same. + std::unique_ptr original_shape = + client_->GetComputationShape(computation).ConsumeValueOrDie(); + std::unique_ptr replayed_shape = + client_->GetComputationShape(replayed).ConsumeValueOrDie(); + ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape)); + + // Destroy the originals. + computation.Reset(); + plus_two.Reset(); + + // Run it. + std::unique_ptr literal = + client_->ExecuteAndTransfer(replayed, /*arguments=*/{}) + .ConsumeValueOrDie(); + + // Expect result. 
+ LiteralTestUtil::ExpectR1Equal({3, 4, 5}, *literal); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/reshape_motion_test.cc b/tensorflow/compiler/xla/tests/reshape_motion_test.cc new file mode 100644 index 0000000000..ce309eb743 --- /dev/null +++ b/tensorflow/compiler/xla/tests/reshape_motion_test.cc @@ -0,0 +1,77 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/reference_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +using ReshapeMotionTest = ClientLibraryTestBase; + +TEST_F(ReshapeMotionTest, ElementwiseOfReshapesWithNonSameInputShapes) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2({{2, 3, 5}, {7, 11, 13}}); + auto b = builder.ConstantR2({{17, 19}, {23, 29}, {31, 37}}); + auto c = builder.Reshape(a, {6}); + auto d = builder.Reshape(b, {6}); + auto e = builder.Mul(c, d); + + ComputeAndCompareR1(&builder, {34, 57, 115, 203, 341, 481}, {}); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + 
LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc new file mode 100644 index 0000000000..a9159d39ca --- /dev/null +++ b/tensorflow/compiler/xla/tests/reshape_test.cc @@ -0,0 +1,811 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/reference_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class ReshapeTest : public ClientLibraryTestBase { + public: + ErrorSpec zero_error_spec_{0.0}; +}; + +// Collapses 2-dimensional pseudo-scalar (single-element array) to 1 dimension. +XLA_TEST_F(ReshapeTest, Trivial1x1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2({{1.0}}); + builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1}); + + ComputeAndCompareR1(&builder, {1.0f}, {}, zero_error_spec_); +} + +// Collapses 2-dimensional pseudo-scalar (single-element array) to scalar. 
+XLA_TEST_F(ReshapeTest, SingleElementArrayToScalar) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2({{1.0}}); + auto reshape = + builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1}, /*new_sizes=*/{}); + auto new_shape = builder.GetShape(reshape).ConsumeValueOrDie(); + + ComputeAndCompareR0(&builder, 1.0f, {}, zero_error_spec_); +} + +XLA_TEST_F(ReshapeTest, Trivial0x3) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2FromArray2D(Array2D(0, 3)); + auto result = builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1}); + + ComputeAndCompareR1(&builder, {}, {}, zero_error_spec_); +} + +XLA_TEST_F(ReshapeTest, Trivial3x0) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2FromArray2D(Array2D(3, 0)); + auto result = builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1}); + + ComputeAndCompareR1(&builder, {}, {}, zero_error_spec_); +} + +// Collapses a 2-dimensional row vector to 1 dimension. +XLA_TEST_F(ReshapeTest, Trivial1x3) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2({{1.0f, 2.0f, 3.0f}}); + auto result = builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1}); + + ComputeAndCompareR1(&builder, {1.0f, 2.0f, 3.0f}, {}, + zero_error_spec_); +} + +// Collapses a 2-dimensional column vector to 1 dimension. +XLA_TEST_F(ReshapeTest, Trivial3x1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2({{1.0f}, {2.0f}, {3.0f}}); + auto result = builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1}); + + ComputeAndCompareR1(&builder, {1.0f, 2.0f, 3.0f}, {}, + zero_error_spec_); +} + +// Splits an empty vector into an empty matrix. +XLA_TEST_F(ReshapeTest, R1ToR2_0_To_2x0) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + auto result = + builder.Reshape(/*operand=*/a, /*dimensions=*/{0}, /*new_sizes=*/{2, 0}); + ComputeAndCompareR2(&builder, Array2D(2, 0), {}, + zero_error_spec_); +} + +// Splits a vector into a matrix. +XLA_TEST_F(ReshapeTest, R1ToR2_6_To_2x3) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); + auto result = + builder.Reshape(/*operand=*/a, /*dimensions=*/{0}, /*new_sizes=*/{2, 3}); + Array2D expected_2x3({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}); + ComputeAndCompareR2(&builder, expected_2x3, {}, zero_error_spec_); +} + +// Transposes a 2x0 array to a 0x2 array. +XLA_TEST_F(ReshapeTest, Reshape0x2To2x0) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2FromArray2D(Array2D(0, 2)); + auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1}, + /*new_sizes=*/{2, 0}); + + ComputeAndCompareR2(&builder, Array2D(2, 0), {}, + zero_error_spec_); +} + +// Transposes a 2-dimensional row vector to a column vector. +XLA_TEST_F(ReshapeTest, ReshapeRowToCol) { + ComputationBuilder builder(client_, TestName()); + auto simple = MakeLinspaceArray2D(1.0f, 3.0f, 1, 3); + auto a = builder.ConstantR2FromArray2D(*simple); + auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1}, + /*new_sizes=*/{3, 1}); + + auto expected = ReferenceUtil::TransposeArray2D(*simple); + ComputeAndCompareR2(&builder, *expected, {}, zero_error_spec_); +} + +// Transposes a 2-dimensional array. 
+XLA_TEST_F(ReshapeTest, TransposeAsReshape) { + ComputationBuilder builder(client_, TestName()); + auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); + auto a = builder.ConstantR2FromArray2D(*a4x3); + auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{1, 0}, + /*new_sizes=*/{3, 4}); + + auto expected3x4 = ReferenceUtil::TransposeArray2D(*a4x3); + ComputeAndCompareR2(&builder, *expected3x4, {}, zero_error_spec_); +} + +// Transposes a 0x4 array with ComputationBuilder::Trans. +XLA_TEST_F(ReshapeTest, Transpose0x4) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2FromArray2D(Array2D(0, 4)); + auto result = builder.Transpose(a, {1, 0}); + + ComputeAndCompareR2(&builder, Array2D(4, 0), {}, + zero_error_spec_); +} + +// Transposes a 2-dimensional array with ComputationBuilder::Trans. +XLA_TEST_F(ReshapeTest, Transpose4x3) { + ComputationBuilder builder(client_, TestName()); + auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); + auto a = builder.ConstantR2FromArray2D(*a4x3); + auto result = builder.Transpose(a, {1, 0}); + + auto expected3x4 = ReferenceUtil::TransposeArray2D(*a4x3); + ComputeAndCompareR2(&builder, *expected3x4, {}, zero_error_spec_); +} + +// Reshapes an empty 2-dimensional array with dimensions that are not just a +// rearrangement of the originals (split), but no reordering (no shuffle). +XLA_TEST_F(ReshapeTest, ReshapeSplitNoShuffleZeroElements) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2FromArray2D(Array2D(6, 0)); + auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1}, + /*new_sizes=*/{2, 3, 0, 0}); + + ComputeAndCompareR4(&builder, Array4D(2, 3, 0, 0), {}, + zero_error_spec_); +} + +XLA_TEST_F(ReshapeTest, ReshapeR4ToR2ZeroElements) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR4FromArray4D(Array4D(2, 3, 4, 0)); + auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1, 2, 3}, + /*new_sizes=*/{24, 0}); + + ComputeAndCompareR2(&builder, Array2D(24, 0), {}, + zero_error_spec_); +} + +// Reshapes a 2-dimensional array with dimensions that are not just a +// rearrangement of the originals (split), but no reordering (no shuffle). +XLA_TEST_F(ReshapeTest, ReshapeSplitNoShuffle) { + ComputationBuilder builder(client_, TestName()); + auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); + auto a = builder.ConstantR2FromArray2D(*a4x3); + auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1}, + /*new_sizes=*/{2, 6}); + + auto expected2x6 = MakeLinspaceArray2D(1.0f, 12.0f, 2, 6); + ComputeAndCompareR2(&builder, *expected2x6, {}, zero_error_spec_); +} + +// Reshapes a 2-dimensional array with dimensions that are not just a +// rearrangement of the originals (split), and reorder the input (shuffle). +XLA_TEST_F(ReshapeTest, ReshapeSplitAndShuffleZeroElements) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2FromArray2D(Array2D(0, 6)); + auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{1, 0}, + /*new_sizes=*/{3, 0}); + + ComputeAndCompareR2(&builder, Array2D(3, 0), {}, + zero_error_spec_); +} + +// Reshapes a 2-dimensional array with dimensions that are not just a +// rearrangement of the originals (split), and reorder the input (shuffle). 
+XLA_TEST_F(ReshapeTest, ReshapeSplitAndShuffle) { + ComputationBuilder builder(client_, TestName()); + auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); + auto a = builder.ConstantR2FromArray2D(*a4x3); + auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{1, 0}, + /*new_sizes=*/{2, 6}); + + Array2D expected2x6({{1.0f, 4.0f, 7.0f, 10.0f, 2.0f, 5.0f}, + {8.0f, 11.0f, 3.0f, 6.0f, 9.0f, 12.0f}}); + ComputeAndCompareR2(&builder, expected2x6, {}, zero_error_spec_); +} + +// The following tests use the same input 3D array; they test the examples we +// show for the Reshape operation in the operation_semantics document. +// TODO(eliben): find a way to show this code in the documentation without +// duplication. +Array3D v_array_for_doc_R3_tests({{{10, 11, 12}, {15, 16, 17}}, + {{20, 21, 22}, {25, 26, 27}}, + {{30, 31, 32}, {35, 36, 37}}, + {{40, 41, 42}, {45, 46, 47}}}); + +XLA_TEST_F(ReshapeTest, DocR3_R1_Collapse_012) { + ComputationBuilder builder(client_, TestName()); + auto v = builder.ConstantR3FromArray3D(v_array_for_doc_R3_tests); + auto result = builder.Reshape(/*operand=*/v, /*dimensions=*/{0, 1, 2}, + /*new_sizes=*/{24}); + ComputeAndCompareR1(&builder, + {10, 11, 12, 15, 16, 17, 20, 21, 22, 25, 26, 27, + 30, 31, 32, 35, 36, 37, 40, 41, 42, 45, 46, 47}, + {}); +} + +XLA_TEST_F(ReshapeTest, DocR3_R2_Collapse_012_Refine_83) { + ComputationBuilder builder(client_, TestName()); + auto v = builder.ConstantR3FromArray3D(v_array_for_doc_R3_tests); + auto result = builder.Reshape(/*operand=*/v, /*dimensions=*/{0, 1, 2}, + /*new_sizes=*/{8, 3}); + Array2D expected({{10, 11, 12}, + {15, 16, 17}, + {20, 21, 22}, + {25, 26, 27}, + {30, 31, 32}, + {35, 36, 37}, + {40, 41, 42}, + {45, 46, 47}}); + ComputeAndCompareR2(&builder, expected, {}); +} + +XLA_TEST_F(ReshapeTest, DocR3_R1_Collapse_120) { + ComputationBuilder builder(client_, TestName()); + auto v = builder.ConstantR3FromArray3D(v_array_for_doc_R3_tests); + auto result = builder.Reshape(/*operand=*/v, /*dimensions=*/{1, 2, 0}, + /*new_sizes=*/{24}); + ComputeAndCompareR1(&builder, + {10, 20, 30, 40, 11, 21, 31, 41, 12, 22, 32, 42, + 15, 25, 35, 45, 16, 26, 36, 46, 17, 27, 37, 47}, + {}); +} + +XLA_TEST_F(ReshapeTest, DocR3_R2_Collapse_120_Refine_83) { + ComputationBuilder builder(client_, TestName()); + auto v = builder.ConstantR3FromArray3D(v_array_for_doc_R3_tests); + auto result = builder.Reshape(/*operand=*/v, /*dimensions=*/{1, 2, 0}, + /*new_sizes=*/{8, 3}); + Array2D expected({{10, 20, 30}, + {40, 11, 21}, + {31, 41, 12}, + {22, 32, 42}, + {15, 25, 35}, + {45, 16, 26}, + {36, 46, 17}, + {27, 37, 47}}); + ComputeAndCompareR2(&builder, expected, {}); +} + +XLA_TEST_F(ReshapeTest, DocR3_R3_Collapse_120_Refine_262) { + ComputationBuilder builder(client_, TestName()); + auto v = builder.ConstantR3FromArray3D(v_array_for_doc_R3_tests); + auto result = builder.Reshape(/*operand=*/v, /*dimensions=*/{1, 2, 0}, + /*new_sizes=*/{2, 6, 2}); + Array3D expected( + {{{10, 20}, {30, 40}, {11, 21}, {31, 41}, {12, 22}, {32, 42}}, + {{15, 25}, {35, 45}, {16, 26}, {36, 46}, {17, 27}, {37, 47}}}); + ComputeAndCompareR3(&builder, expected, {}); +} + +// Collapses the low dimensions of a 4D tensor to get a 2D matrix, without +// reordering dimensions (for NeuralNet::FullyConnected). 
+// +// First we create a tesseract raster-face like: +// +// 1 2 3 +// 4 5 6 +// +// First we collapse Y and X within the raster space yielding: +// +// 1 2 3 4 5 6 +// +// Then we collapse Z be collapsed so we just end up with planes: +// +// 1 2 3 4 5 6 1 2 3 4 5 6 +XLA_TEST_F(ReshapeTest, FullyConnectedCollapse) { + ComputationBuilder builder(client_, TestName()); + Array4D t2x2x2x3(2, 2, 2, 3); + auto filler2x3 = MakeLinspaceArray2D(1.0f, 6.0f, 2, 3); + t2x2x2x3.FillWithYX(*filler2x3); + auto a = builder.ConstantR4FromArray4D(t2x2x2x3); + auto result = builder.Collapse(/*operand=*/a, /*dimensions=*/{1, 2, 3}); + + Array2D expected2x12( + {{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, + {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, + 6.0f}}); + ComputeAndCompareR2(&builder, expected2x12, {}, zero_error_spec_); +} + +// As above, but uses reshape directly. +XLA_TEST_F(ReshapeTest, FullyConnectedCollapseDesugared) { + ComputationBuilder builder(client_, TestName()); + Array4D t(2, 1, 2, 2); + t(0, 0, 0, 0) = 0; + t(0, 0, 0, 1) = 1; + t(0, 0, 1, 0) = 2; + t(0, 0, 1, 1) = 3; + t(1, 0, 0, 0) = 4; + t(1, 0, 0, 1) = 5; + t(1, 0, 1, 0) = 6; + t(1, 0, 1, 1) = 7; + auto a = builder.ConstantR4FromArray4D(t); + auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1, 2, 3}, + /*new_sizes=*/{2, 4}); + + Array2D expected({{0, 1, 2, 3}, {4, 5, 6, 7}}); + ComputeAndCompareR2(&builder, expected, {}, zero_error_spec_); +} + +// Reshape various ranks to a scalar. +XLA_TEST_F(ReshapeTest, ToScalar) { + for (int rank = 0; rank < 8; ++rank) { + ComputationBuilder b(client_, TestName()); + auto input = LiteralUtil::CreateR1({83.0f}); + std::vector ones(rank, 1); // this is {1, ..., 1}. + std::vector dimensions(rank); + std::iota(dimensions.begin(), dimensions.end(), 0); + *input->mutable_shape() = ShapeUtil::MakeShape(F32, ones); + b.Reshape(b.ConstantLiteral(*input), dimensions, {}); + + ComputeAndCompareR0(&b, 83.0f, {}, zero_error_spec_); + } +} + +XLA_TEST_F(ReshapeTest, BadDimensions) { + ComputationBuilder b(client_, TestName()); + b.Reshape(b.ConstantR1({1}), {}, {}); + EXPECT_MATCH(ExecuteToString(&b, {}), + testing::HasSubstr("dimensions not a permutation")); +} + +XLA_TEST_F(ReshapeTest, BadNewSizes) { + ComputationBuilder b(client_, TestName()); + b.Reshape(b.ConstantR1({1, 2}), {1}, {}); + EXPECT_MATCH(ExecuteToString(&b, {}), + testing::HasSubstr("mismatched element counts")); +} + +XLA_TEST_F(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { + const Shape parameter_shape = ShapeUtil::MakeShape(F32, {2, 2, 2, 2}); + ComputationBuilder builder(client_, TestName()); + auto a = builder.Parameter(0, parameter_shape, "a"); + builder.Reshape(a, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 8}); + + // clang-format off + auto literal = LiteralUtil::CreateR4FromArray4DWithLayout(Array4D{ + { + { + {0, 1}, + {2, 3}, + }, + { + {100, 101}, + {102, 103}, + }, + }, + { + { + {222, 333}, + {444, 555}, + }, + { + {666, 777}, + {888, 999}, + }, + }, + }, + LayoutUtil::MakeLayout({0, 1, 2, 3})); + // clang-format on + std::unique_ptr input = + client_->TransferToServer(*literal).ConsumeValueOrDie(); + Array2D expected_array({ + {0, 1, 2, 3, 100, 101, 102, 103}, + {222, 333, 444, 555, 666, 777, 888, 999}, + }); + + Computation computation = builder.Build().ConsumeValueOrDie(); + const Shape shape_with_output_layout = + ShapeUtil::MakeShapeWithLayout(F32, {2, 8}, {1, 0}); + std::unique_ptr actual = + client_ + ->ExecuteAndTransfer(computation, 
{input.get()}, + &shape_with_output_layout) + .ConsumeValueOrDie(); + std::unique_ptr expected = + LiteralUtil::CreateR2FromArray2D(expected_array); + LiteralTestUtil::ExpectEqual(*expected, *actual); +} + +XLA_TEST_F(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { + std::unique_ptr input = LiteralUtil::CreateR2({ + {0, 1, 2, 3, 4, 5, 6, 7}, + {100, 101, 102, 103, 104, 105, 106, 107}, + {200, 201, 202, 203, 204, 205, 206, 207}, + }); + std::unique_ptr input_data = + client_->TransferToServer(*input).ConsumeValueOrDie(); + + ComputationBuilder builder(client_, TestName()); + auto a = builder.Parameter(0, input->shape(), "a"); + builder.Reshape(a, /*dimensions=*/{0, 1}, /*new_sizes=*/{3, 2, 1, 4}); + + // clang-format off + Array4D expected = { + {{{0, 1, 2, 3}}, + {{4, 5, 6, 7}}}, + {{{100, 101, 102, 103}}, + {{104, 105, 106, 107}}}, + {{{200, 201, 202, 203}}, + {{204, 205, 206, 207}}} + }; + // clang-format on + ComputeAndCompareR4(&builder, expected, {input_data.get()}, + zero_error_spec_); +} + +// Tests R2->R4 reshape with the reshape dimensions {1, 0}. +XLA_TEST_F(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) { + std::unique_ptr input = LiteralUtil::CreateR2({ + {0, 1, 2, 3, 4, 5, 6, 7}, + {100, 101, 102, 103, 104, 105, 106, 107}, + {200, 201, 202, 203, 204, 205, 206, 207}, + }); + std::unique_ptr input_data = + client_->TransferToServer(*input).ConsumeValueOrDie(); + + ComputationBuilder builder(client_, TestName()); + auto a = builder.Parameter(0, input->shape(), "a"); + builder.Reshape(a, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 2, 1, 4}); + + // clang-format off + Array4D expected = { + {{{0, 100, 200, 1}}, + {{101, 201, 2, 102}}}, + {{{202, 3, 103, 203}}, + {{4, 104, 204, 5}}}, + {{{105, 205, 6, 106}}, + {{206, 7, 107, 207}}} + }; + // clang-format on + ComputeAndCompareR4(&builder, expected, {input_data.get()}, + zero_error_spec_); +} + +XLA_TEST_F(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) { + std::mt19937 rng; + std::uniform_real_distribution distribution; + Array4D input(2, 1, 1, 1); + input.Each( + [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, + float* cell) { *cell = distribution(rng); }); + std::unique_ptr input_literal = + LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout({3, 2, 1, 0})); + std::unique_ptr input_data = + client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + + ComputationBuilder builder(client_, TestName()); + auto a = builder.Parameter(0, input_literal->shape(), "a"); + builder.Reshape(a, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 1}); + + std::unique_ptr expected = + LiteralTestUtil::Reshape({2, 1}, {1, 0}, *input_literal); + ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, + zero_error_spec_); +} + +XLA_TEST_F(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) { + std::mt19937 rng; + std::uniform_real_distribution distribution; + Array4D input(2, 1, 4, 1); + input.Each( + [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, + float* cell) { *cell = distribution(rng); }); + std::unique_ptr input_literal = + LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout({3, 2, 1, 0})); + std::unique_ptr input_data = + client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + + ComputationBuilder builder(client_, TestName()); + auto a = builder.Parameter(0, input_literal->shape(), "a"); + builder.Reshape(a, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{4, 2}); + + std::unique_ptr expected = + LiteralTestUtil::Reshape({4, 2}, {1, 0}, *input_literal); + 
ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, + zero_error_spec_); +} + +// Tests R4->R2 reshape with the reshape dimensions {0, 2, 1, 3}. +XLA_TEST_F(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { + std::mt19937 rng; + std::uniform_real_distribution distribution; + Array4D input(5, 10, 2, 3); + input.Each( + [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, + float* cell) { *cell = distribution(rng); }); + std::unique_ptr input_literal = + LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout({3, 2, 1, 0})); + std::unique_ptr input_data = + client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + + ComputationBuilder builder(client_, TestName()); + auto a = builder.Parameter(0, input_literal->shape(), "a"); + builder.Reshape(a, /*dimensions=*/{0, 2, 1, 3}, /*new_sizes=*/{5, 60}); + + Array2D expected_array(5, 60); + input.Each([&](tensorflow::gtl::ArraySlice indices, float* cell) { + expected_array(indices[0], indices[2] * 30 + indices[1] * 3 + indices[3]) = + *cell; + }); + auto expected = LiteralUtil::CreateR2FromArray2D(expected_array); + ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}); +} + +XLA_TEST_F(ReshapeTest, NoopReshape) { + std::mt19937 rng; + std::uniform_real_distribution distribution; + Array4D input_array(2, 3, 5, 7); + input_array.Each( + [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, + float* cell) { *cell = distribution(rng); }); + std::unique_ptr input_literal = + LiteralUtil::CreateR4FromArray4DWithLayout( + input_array, LayoutUtil::MakeLayout({1, 2, 3, 0})); + std::unique_ptr input_data = + client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + + ComputationBuilder builder(client_, TestName()); + auto input = builder.Parameter(0, input_literal->shape(), "input"); + builder.Reshape(input, /*dimensions=*/{3, 0, 1, 2}, + /*new_sizes=*/{7, 2, 3, 5}); + Computation computation = builder.Build().ConsumeValueOrDie(); + + const Shape output_shape_with_layout = + ShapeUtil::MakeShapeWithLayout(F32, {7, 2, 3, 5}, {2, 3, 0, 1}); + std::unique_ptr output_literal = + client_ + ->ExecuteAndTransfer(computation, {input_data.get()}, + &output_shape_with_layout) + .ConsumeValueOrDie(); + + // Since the reshape is a no-op, verify that it does not change the underlying + // data. 
+ EXPECT_EQ(tensorflow::gtl::ArraySlice(input_literal->f32s()), + tensorflow::gtl::ArraySlice(output_literal->f32s())); +} + +XLA_TEST_F(ReshapeTest, R4ToR4Reshape_Trivial) { + auto literal_1x2x3x4 = LiteralUtil::CreateR4( + {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, + {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); + + ComputationBuilder builder(client_, TestName()); + auto input = builder.ConstantLiteral(*literal_1x2x3x4); + builder.Reshape(input, /*dimensions=*/{0, 1, 2, 3}, + /*new_sizes=*/{1, 2, 3, 4}); + + ComputeAndCompareLiteral(&builder, *literal_1x2x3x4, {}); +} + +XLA_TEST_F(ReshapeTest, R4ToR4Reshape) { + auto literal_1x2x3x4 = LiteralUtil::CreateR4( + {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, + {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); + + ComputationBuilder builder(client_, TestName()); + auto input = builder.ConstantLiteral(*literal_1x2x3x4); + builder.Reshape(input, /*dimensions=*/{1, 3, 2, 0}, + /*new_sizes=*/{2, 4, 3, 1}); + + // clang-format off + auto expected_2x4x3x1 = LiteralUtil::CreateR4( + {{{{1}, {5}, {9}}, + {{2}, {6}, {10}}, + {{3}, {7}, {11}}, + {{4}, {8}, {12}}}, + {{{13}, {17}, {21}}, + {{14}, {18}, {22}}, + {{15}, {19}, {23}}, + {{16}, {20}, {24}}}}); + // clang-format on + + ComputeAndCompareLiteral(&builder, *expected_2x4x3x1, {}); +} + +XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeSimple) { + std::mt19937 rng; + std::uniform_real_distribution distribution; + std::vector bounds = {2, 2, 2, 2}; + std::vector new_bounds = {bounds[0], bounds[1], bounds[3], bounds[2]}; + Array4D input(bounds[0], bounds[1], bounds[2], bounds[3]); + input.Each( + [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, + float* cell) { *cell = distribution(rng); }); + std::unique_ptr input_literal = + LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout({3, 2, 1, 0})); + std::unique_ptr input_data = + client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + + ComputationBuilder builder(client_, TestName()); + auto a = builder.Parameter(0, input_literal->shape(), "a"); + builder.Reshape(a, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); + + std::unique_ptr expected = LiteralUtil::Relayout( + *LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal), + LayoutUtil::MakeLayout({3, 2, 1, 0})); + + // Specify the requested output shape explicitly to ensure that this reshape + // actually corresponds to a two minor transpose. 
+ ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, + zero_error_spec_, &expected->shape()); +} + +XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { + std::mt19937 rng; + std::uniform_real_distribution distribution; + std::vector bounds = {1, 1, 250, 300}; + std::vector new_bounds = {bounds[0], bounds[1], bounds[3], bounds[2]}; + Array4D input(bounds[0], bounds[1], bounds[2], bounds[3]); + input.Each( + [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, + float* cell) { *cell = distribution(rng); }); + std::unique_ptr input_literal = + LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout({3, 2, 1, 0})); + std::unique_ptr input_data = + client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + + ComputationBuilder builder(client_, TestName()); + auto a = builder.Parameter(0, input_literal->shape(), "a"); + builder.Reshape(a, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); + + std::unique_ptr expected = LiteralUtil::Relayout( + *LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal), + LayoutUtil::MakeLayout({3, 2, 1, 0})); + + // Specify the requested output shape explicitly to ensure that this reshape + // actually corresponds to a two minor transpose. + ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, + zero_error_spec_, &expected->shape()); +} + +XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { + std::mt19937 rng; + std::uniform_real_distribution distribution; + std::vector bounds = {5, 5, 1, 10}; + std::vector new_bounds = {bounds[0], bounds[1], bounds[3], bounds[2]}; + Array4D input(bounds[0], bounds[1], bounds[2], bounds[3]); + input.Each( + [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, + float* cell) { *cell = distribution(rng); }); + std::unique_ptr input_literal = + LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout({3, 2, 1, 0})); + std::unique_ptr input_data = + client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + + ComputationBuilder builder(client_, TestName()); + auto a = builder.Parameter(0, input_literal->shape(), "a"); + builder.Reshape(a, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); + + std::unique_ptr expected = LiteralUtil::Relayout( + *LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal), + LayoutUtil::MakeLayout({3, 2, 1, 0})); + + // Specify the requested output shape explicitly to ensure that this reshape + // actually corresponds to a two minor transpose. + ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, + zero_error_spec_, &expected->shape()); +} + +XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { + std::mt19937 rng; + std::uniform_real_distribution distribution; + // This happens in NN-Builder MNIST. 
+ std::vector bounds = {5, 5, 10, 1}; + std::vector new_bounds = {bounds[0], bounds[1], bounds[3], bounds[2]}; + Array4D input(bounds[0], bounds[1], bounds[2], bounds[3]); + input.Each( + [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, + float* cell) { *cell = distribution(rng); }); + std::unique_ptr input_literal = + LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout({3, 2, 1, 0})); + std::unique_ptr input_data = + client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + + ComputationBuilder builder(client_, TestName()); + auto a = builder.Parameter(0, input_literal->shape(), "a"); + builder.Reshape(a, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); + + std::unique_ptr expected = LiteralUtil::Relayout( + *LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal), + LayoutUtil::MakeLayout({3, 2, 1, 0})); + + // Specify the requested output shape explicitly to ensure that this reshape + // actually corresponds to a two minor transpose. + ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, + zero_error_spec_, &expected->shape()); +} + +XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeTrivialR2) { + std::mt19937 rng; + std::uniform_real_distribution distribution; + std::vector bounds = {3, 3, 1, 3}; + std::vector new_bounds = {bounds[1], bounds[0], bounds[2], bounds[3]}; + Array4D input(bounds[0], bounds[1], bounds[2], bounds[3]); + input.Each( + [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, + float* cell) { *cell = distribution(rng); }); + std::unique_ptr input_literal = + LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout({0, 1, 2, 3})); + std::unique_ptr input_data = + client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + + ComputationBuilder builder(client_, TestName()); + auto a = builder.Parameter(0, input_literal->shape(), "a"); + builder.Reshape(a, /*dimensions=*/{1, 0, 2, 3}, /*new_sizes=*/new_bounds); + + std::unique_ptr expected = LiteralUtil::Relayout( + *LiteralTestUtil::Reshape(new_bounds, {1, 0, 2, 3}, *input_literal), + input_literal->shape().layout()); + + // Specify the requested output shape explicitly to ensure that this reshape + // actually corresponds to a two minor transpose. + ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, + zero_error_spec_, &expected->shape()); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/reverse_test.cc b/tensorflow/compiler/xla/tests/reverse_test.cc new file mode 100644 index 0000000000..63dd4421fa --- /dev/null +++ b/tensorflow/compiler/xla/tests/reverse_test.cc @@ -0,0 +1,173 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class ReverseTest : public ClientLibraryTestBase {}; + +// Tests the reverse operation on a scalar. +XLA_TEST_F(ReverseTest, ReverseScalar) { + ComputationBuilder b(client_, TestName()); + float input = 3.5f; + b.Rev(b.ConstantR0(input), {}); + ComputeAndCompareR0(&b, input, {}); +} + +// Tests the reverse operation on a 0x0 float array on both dimensions. +XLA_TEST_F(ReverseTest, Reverse0x0FloatArray) { + ComputationBuilder b(client_, TestName()); + b.Rev(b.ConstantR2FromArray2D(Array2D(0, 0)), {0, 1}); + ComputeAndCompareR2(&b, Array2D(0, 0), {}); +} + +// Tests the reverse operation on a 0x1 float array on both dimensions. +XLA_TEST_F(ReverseTest, Reverse0x1FloatArray) { + ComputationBuilder b(client_, TestName()); + b.Rev(b.ConstantR2FromArray2D(Array2D(0, 1)), {0, 1}); + ComputeAndCompareR2(&b, Array2D(0, 1), {}); +} + +// Tests the reverse operation on a 1x0 float array on both dimensions. +XLA_TEST_F(ReverseTest, Reverse1x0FloatArray) { + ComputationBuilder b(client_, TestName()); + b.Rev(b.ConstantR2FromArray2D(Array2D(1, 0)), {0, 1}); + ComputeAndCompareR2(&b, Array2D(1, 0), {}); +} + +// Tests the reverse operation on a 1x1 float array on both dimensions. +XLA_TEST_F(ReverseTest, Reverse1x1FloatArray) { + ComputationBuilder b(client_, TestName()); + Array2D input({{3.5f}}); + b.Rev(b.ConstantR2FromArray2D(input), {0, 1}); + ComputeAndCompareR2(&b, input, {}); +} + +XLA_TEST_F(ReverseTest, Reverse2x0x4x3FloatArrayDim02) { + ComputationBuilder b(client_, TestName()); + b.Rev(b.ConstantR4FromArray4D(Array4D(2, 0, 4, 3)), {0, 2}); + ComputeAndCompareR4(&b, Array4D(2, 0, 4, 3), {}); +} + +XLA_TEST_F(ReverseTest, Reverse2x0x4x3FloatArrayDim13) { + ComputationBuilder b(client_, TestName()); + b.Rev(b.ConstantR4FromArray4D(Array4D(2, 0, 4, 3)), {1, 3}); + ComputeAndCompareR4(&b, Array4D(2, 0, 4, 3), {}); +} + +// Tests the reverse operation on a 4D U8 array on dimension 0 and 3. +XLA_TEST_F(ReverseTest, Reverse4DU8ArrayOnDim23) { + ComputationBuilder b(client_, TestName()); + // Input shape is U8[1x2x3x4]. 
+ // clang-format off + Array4D input({{ + {{1, 2, 3, 4}, + {5, 6, 7, 8}, + {9, 10, 11, 12}}, + {{13, 14, 15, 16}, + {17, 18, 19, 20}, + {21, 22, 23, 24}}, + }}); + // clang-format on + + b.Rev(b.ConstantR4FromArray4D(input), {0, 3}); + + // clang-format off + Array4D expected({{ + {{4, 3, 2, 1}, + {8, 7, 6, 5}, + {12, 11, 10, 9}}, + {{16, 15, 14, 13}, + {20, 19, 18, 17}, + {24, 23, 22, 21}}, + }}); + // clang-format on + ComputeAndCompareR4(&b, expected, {}); +} + +// Tests the reverse operation on a 4D float array on dimension 0 and 1. +TEST_F(ReverseTest, Reverse4DFloatArrayOnDim01) { + ComputationBuilder b(client_, TestName()); + // Input shape is float[4x3x2x1]. + // clang-format off + Array4D input({ + {{{1.0f}, {2.0f}}, + {{3.0f}, {4.0f}}, + {{5.0f}, {6.0f}}}, + {{{7.0f}, {8.0f}}, + {{9.0f}, {10.0f}}, + {{11.0f}, {12.0f}}}, + {{{13.0f}, {14.0f}}, + {{15.0f}, {16.0f}}, + {{17.0f}, {18.0f}}}, + {{{19.0f}, {20.0f}}, + {{21.0f}, {22.0f}}, + {{23.0f}, {24.0f}}}, + }); + // clang-format on + + b.Rev(b.ConstantR4FromArray4D(input), {0, 1}); + + // clang-format off + Array4D expected({ + {{{23.0f}, {24.0f}}, + {{21.0f}, {22.0f}}, + {{19.0f}, {20.0f}}}, + {{{17.0f}, {18.0f}}, + {{15.0f}, {16.0f}}, + {{13.0f}, {14.0f}}}, + {{{11.0f}, {12.0f}}, + {{9.0f}, {10.0f}}, + {{7.0f}, {8.0f}}}, + {{{5.0f}, {6.0f}}, + {{3.0f}, {4.0f}}, + {{1.0f}, {2.0f}}}, + }); + // clang-format on + ComputeAndCompareR4(&b, expected, {}, ErrorSpec(0.0001)); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc new file mode 100644 index 0000000000..5b734c0f40 --- /dev/null +++ b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc @@ -0,0 +1,160 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/packed_literal_reader.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/casts.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class RoundTripPackedLiteralTest : public ClientLibraryTestBase { + protected: + // Sends the literal to the server and retrieves it back. + std::unique_ptr RoundTripToServer(const Literal& original) { + std::unique_ptr data = + client_->TransferToServer(original).ConsumeValueOrDie(); + return client_->Transfer(*data).ConsumeValueOrDie(); + } +}; + +TEST_F(RoundTripPackedLiteralTest, RoundTripsR1F32Length2) { + string data(sizeof(float) * 2, 0); + tensorflow::gtl::MutableArraySlice floats( + tensorflow::bit_cast(data.data()), 2); + floats[0] = 42.0; + floats[1] = 24.0; + + string fname = tensorflow::testing::TmpDir() + "/RoundTripsR1F32Length2.data"; + EXPECT_TRUE( + tensorflow::WriteStringToFile(tensorflow::Env::Default(), fname, data) + .ok()); + + std::unique_ptr f; + TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(fname, &f)); + PackedLiteralReader reader(f.release()); + std::unique_ptr actual = + reader.Read(ShapeUtil::MakeShape(F32, {2})).ConsumeValueOrDie(); + EXPECT_TRUE(reader.IsExhausted()); + + EXPECT_EQ(42.0, LiteralUtil::Get(*actual, {0})); + EXPECT_EQ(24.0, LiteralUtil::Get(*actual, {1})); +} + +TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim0Minor) { + string data(sizeof(float) * 4, 0); + tensorflow::gtl::MutableArraySlice floats( + tensorflow::bit_cast(data.data()), 4); + // With x as the minor dimension, these will become: + floats[0] = 42.0; // y=0,x=0 + floats[1] = 24.0; // y=0,x=1 + floats[2] = 64.0; // y=1,x=0 + floats[3] = 46.0; // y=1,x=1 + + string fname = + tensorflow::testing::TmpDir() + "/RoundTripsR2F32Size2x2Dim0Minor.data"; + EXPECT_TRUE( + tensorflow::WriteStringToFile(tensorflow::Env::Default(), fname, data) + .ok()); + + const Layout layout = LayoutUtil::MakeLayout({1, 0}); + + std::unique_ptr f; + TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(fname, &f)); + PackedLiteralReader reader(f.release()); + std::unique_ptr actual = + reader.Read(ShapeUtil::MakeShape(F32, {2, 2}), &layout) + .ConsumeValueOrDie(); + EXPECT_TRUE(reader.IsExhausted()); + + EXPECT_EQ(42.0f, LiteralUtil::Get(*actual, {0, 0})); + EXPECT_EQ(24.0f, LiteralUtil::Get(*actual, {0, 1})); + EXPECT_EQ(64.0f, LiteralUtil::Get(*actual, {1, 0})); + EXPECT_EQ(46.0f, LiteralUtil::Get(*actual, {1, 1})); + + std::unique_ptr round_tripped = RoundTripToServer(*actual); + LiteralTestUtil::ExpectEqual(*round_tripped, *actual); +} + +TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) { + string data(sizeof(float) * 4, 0); + 
tensorflow::gtl::MutableArraySlice floats( + tensorflow::bit_cast(data.data()), 4); + // With y as the minor dimension, these will become: + floats[0] = 42.0; // y=0,x=0 + floats[1] = 24.0; // y=1,x=0 + floats[2] = 64.0; // y=0,x=1 + floats[3] = 46.0; // y=1,x=1 + + string fname = + tensorflow::testing::TmpDir() + "/RoundTripsR2F32Size2x2Dim1Minor.data"; + EXPECT_TRUE( + tensorflow::WriteStringToFile(tensorflow::Env::Default(), fname, data) + .ok()); + + const Layout layout = LayoutUtil::MakeLayout({0, 1}); + + std::unique_ptr f; + TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(fname, &f)); + PackedLiteralReader reader(f.release()); + std::unique_ptr actual = + reader.Read(ShapeUtil::MakeShape(F32, {2, 2}), &layout) + .ConsumeValueOrDie(); + EXPECT_TRUE(reader.IsExhausted()); + + EXPECT_EQ(42.0f, LiteralUtil::Get(*actual, {0, 0})); + EXPECT_EQ(24.0f, LiteralUtil::Get(*actual, {1, 0})); + EXPECT_EQ(64.0f, LiteralUtil::Get(*actual, {0, 1})); + EXPECT_EQ(46.0f, LiteralUtil::Get(*actual, {1, 1})); + + std::unique_ptr round_tripped = RoundTripToServer(*actual); + LiteralTestUtil::ExpectEqual(*round_tripped, *actual); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc new file mode 100644 index 0000000000..04a8bab0eb --- /dev/null +++ b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc @@ -0,0 +1,164 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Tests transferring literals of various shapes and values in and out of the +// XLA service. 
+ +#include +#include +#include + +#include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class RoundTripTransferTest : public ClientLibraryTestBase { + protected: + void RoundTripTest(const Literal& original) { + std::unique_ptr data = + client_->TransferToServer(original).ConsumeValueOrDie(); + std::unique_ptr result = + client_->Transfer(*data).ConsumeValueOrDie(); + LiteralTestUtil::ExpectEqual(original, *result); + } +}; + +TEST_F(RoundTripTransferTest, R0S32) { + RoundTripTest(*LiteralUtil::CreateR0(42)); +} + +TEST_F(RoundTripTransferTest, R0F32) { + RoundTripTest(*LiteralUtil::CreateR0(42.0)); +} + +TEST_F(RoundTripTransferTest, R1F32_Len0) { + RoundTripTest(*LiteralUtil::CreateR1({})); +} + +TEST_F(RoundTripTransferTest, R1F32_Len2) { + RoundTripTest(*LiteralUtil::CreateR1({42.0, 64.0})); +} + +TEST_F(RoundTripTransferTest, R1F32_Len256) { + std::vector values(256); + std::iota(values.begin(), values.end(), 1.0); + RoundTripTest(*LiteralUtil::CreateR1(values)); +} + +TEST_F(RoundTripTransferTest, R1F32_Len1024) { + std::vector values(1024); + std::iota(values.begin(), values.end(), 1.0); + RoundTripTest(*LiteralUtil::CreateR1(values)); +} + +TEST_F(RoundTripTransferTest, R1F32_Len1025) { + std::vector values(1025); + std::iota(values.begin(), values.end(), 1.0); + RoundTripTest(*LiteralUtil::CreateR1(values)); +} + +TEST_F(RoundTripTransferTest, R1F32_Len4096) { + std::vector values(4096); + std::iota(values.begin(), values.end(), 1.0); + RoundTripTest(*LiteralUtil::CreateR1(values)); +} + +TEST_F(RoundTripTransferTest, R2F32_Len10x0) { + RoundTripTest( + *LiteralUtil::CreateR2FromArray2D(Array2D(10, 0))); +} + +TEST_F(RoundTripTransferTest, R2F32_Len2x2) { + RoundTripTest(*LiteralUtil::CreateR2({{42.0, 64.0}, {77.0, 88.0}})); +} + +TEST_F(RoundTripTransferTest, R3F32) { + RoundTripTest( + *LiteralUtil::CreateR3({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, + {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}})); +} + +TEST_F(RoundTripTransferTest, R4F32) { + RoundTripTest(*LiteralUtil::CreateR4({{ + {{10, 11, 12, 13}, {14, 15, 16, 17}}, + {{18, 19, 20, 21}, {22, 23, 24, 25}}, + {{26, 27, 28, 29}, {30, 31, 32, 33}}, + }})); +} + +TEST_F(RoundTripTransferTest, EmptyTuple) { + RoundTripTest(*LiteralUtil::MakeTuple({})); +} + +TEST_F(RoundTripTransferTest, TupleOfR1F32) { + RoundTripTest( + *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1, 2}).get(), + LiteralUtil::CreateR1({3, 4}).get()})); +} + +TEST_F(RoundTripTransferTest, TupleOfR1F32_Len0_Len2) { + RoundTripTest( + *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({}).get(), + LiteralUtil::CreateR1({3, 4}).get()})); +} + +TEST_F(RoundTripTransferTest, TupleOfR0F32AndR1S32) { + RoundTripTest( + *LiteralUtil::MakeTuple({LiteralUtil::CreateR0(1.0).get(), + LiteralUtil::CreateR1({2, 3}).get()})); +} + +// Below two tests are added to identify the cost of large data transfers. 
+TEST_F(RoundTripTransferTest, R2F32_Large) { + RoundTripTest(*LiteralUtil::CreateR2F32Linspace(-1.0f, 1.0f, 512, 512)); +} + +TEST_F(RoundTripTransferTest, R4F32_Large) { + Array4D array4d(2, 2, 256, 256); + array4d.FillWithMultiples(1.0f); + RoundTripTest(*LiteralUtil::CreateR4FromArray4D(array4d)); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc new file mode 100644 index 0000000000..bd9cae4d1d --- /dev/null +++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc @@ -0,0 +1,630 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/llvm_backend_flags.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class ScalarComputationsTest : public ClientLibraryTestBase { + public: + ErrorSpec error_spec_{0.0001}; + + protected: + // A template for building and running a binary comparison test. 
+ template + void TestCompare(NativeT lhs, NativeT rhs, bool expected, + ComputationDataHandle (ComputationBuilder::*op)( + const ComputationDataHandle&, + const ComputationDataHandle&, + tensorflow::gtl::ArraySlice)) { + ComputationBuilder builder(client_, TestName()); + ComputationDataHandle lhs_op = builder.ConstantR0(lhs); + ComputationDataHandle rhs_op = builder.ConstantR0(rhs); + ComputationDataHandle result = (builder.*op)(lhs_op, rhs_op, {}); + ComputeAndCompareR0(&builder, expected, {}); + } + + template + void TestMinMax(NativeT lhs, NativeT rhs, NativeT expected, + ComputationDataHandle (ComputationBuilder::*op)( + const ComputationDataHandle&, + const ComputationDataHandle&, + tensorflow::gtl::ArraySlice)) { + ComputationBuilder builder(client_, TestName()); + ComputationDataHandle lhs_op = builder.ConstantR0(lhs); + ComputationDataHandle rhs_op = builder.ConstantR0(rhs); + ComputationDataHandle result = (builder.*op)(lhs_op, rhs_op, {}); + ComputeAndCompareR0(&builder, expected, {}); + } +}; + +TEST_F(ScalarComputationsTest, NegateScalarF32) { + ComputationBuilder builder(client_, TestName()); + builder.Neg(builder.ConstantR0(2.1f)); + + ComputeAndCompareR0(&builder, -2.1f, {}, error_spec_); +} + +TEST_F(ScalarComputationsTest, NegateScalarS32) { + ComputationBuilder builder(client_, TestName()); + builder.Neg(builder.ConstantR0(2)); + + ComputeAndCompareR0(&builder, -2, {}); +} + +TEST_F(ScalarComputationsTest, AddTwoScalarsF32) { + ComputationBuilder builder(client_, TestName()); + builder.Add(builder.ConstantR0(2.1f), builder.ConstantR0(5.5f)); + + ComputeAndCompareR0(&builder, 7.6f, {}, error_spec_); +} + +TEST_F(ScalarComputationsTest, AddTwoScalarsS32) { + ComputationBuilder builder(client_, TestName()); + builder.Add(builder.ConstantR0(2), builder.ConstantR0(5)); + + ComputeAndCompareR0(&builder, 7, {}); +} + +TEST_F(ScalarComputationsTest, AddTwoScalarsU32) { + ComputationBuilder builder(client_, TestName()); + builder.Add(builder.ConstantR0(35), builder.ConstantR0(57)); + + ComputeAndCompareR0(&builder, 92, {}); +} + +XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsU8) { + ComputationBuilder builder(client_, TestName()); + builder.Add(builder.ConstantR0(35), builder.ConstantR0(57)); + + ComputeAndCompareR0(&builder, 92, {}); +} + +XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsU64) { + ComputationBuilder builder(client_, TestName()); + const uint64 a = static_cast(1) << 63; + const uint64 b = a + 1; + builder.Add(builder.ConstantR0(a), builder.ConstantR0(b)); + + ComputeAndCompareR0(&builder, a + b, {}); +} + +XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsS64) { + ComputationBuilder builder(client_, TestName()); + const int64 a = static_cast(1) << 62; + const int64 b = a + 1; + builder.Add(builder.ConstantR0(a), builder.ConstantR0(b)); + + ComputeAndCompareR0(&builder, a + b, {}); +} + +XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsF64) { + ComputationBuilder builder(client_, TestName()); + builder.Add(builder.ConstantR0(0.25), + builder.ConstantR0(3.5)); + + ComputeAndCompareR0(&builder, 3.75, {}); +} + +TEST_F(ScalarComputationsTest, SubtractTwoScalarsF32) { + ComputationBuilder builder(client_, TestName()); + builder.Sub(builder.ConstantR0(2.1f), builder.ConstantR0(5.5f)); + + ComputeAndCompareR0(&builder, -3.4f, {}, error_spec_); +} + +TEST_F(ScalarComputationsTest, SubtractTwoScalarsS32) { + ComputationBuilder builder(client_, TestName()); + builder.Sub(builder.ConstantR0(2), builder.ConstantR0(5)); + + ComputeAndCompareR0(&builder, -3, {}); +} + 
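The TestCompare and TestMinMax helpers above select the comparison to run through a pointer to a ComputationBuilder member function whose trailing parameter is the broadcast-dimensions slice, and they invoke it with the pointer-to-member call syntax `(builder.*op)(lhs_op, rhs_op, {})`. The following is a minimal, self-contained sketch of that dispatch pattern only; `MiniBuilder`, `RunCompare`, and their signatures are hypothetical stand-ins, not the real XLA API.

#include <iostream>
#include <vector>

// Hypothetical stand-in for ComputationBuilder (illustration only, not the
// XLA API): two binary ops sharing the signature the test helpers expect,
// i.e. (lhs, rhs, trailing dimension list).
struct MiniBuilder {
  bool Eq(const int& lhs, const int& rhs, const std::vector<int>& /*dims*/) {
    return lhs == rhs;
  }
  bool Ge(const int& lhs, const int& rhs, const std::vector<int>& /*dims*/) {
    return lhs >= rhs;
  }
};

// Mirrors the shape of TestCompare: the operation under test is chosen by a
// pointer-to-member-function, so a single helper covers Eq/Ne/Ge/Gt/Le/Lt.
bool RunCompare(MiniBuilder& b,
                bool (MiniBuilder::*op)(const int&, const int&,
                                        const std::vector<int>&),
                int lhs, int rhs) {
  // Same call form as (builder.*op)(lhs_op, rhs_op, {}) in the test helpers.
  return (b.*op)(lhs, rhs, {});
}

int main() {
  MiniBuilder b;
  std::cout << RunCompare(b, &MiniBuilder::Eq, 3, 3) << "\n";  // prints 1
  std::cout << RunCompare(b, &MiniBuilder::Ge, 2, 5) << "\n";  // prints 0
  return 0;
}

Passing the operation as data this way is what lets the comparison tests below stay one line each (`TestCompare<int32>(2, 1, false, &ComputationBuilder::Eq)` and so on) instead of repeating the builder boilerplate per operator.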
+TEST_F(ScalarComputationsTest, MulThreeScalarsF32) { + ComputationBuilder builder(client_, TestName()); + builder.Mul(builder.Mul(builder.ConstantR0(2.1f), + builder.ConstantR0(5.5f)), + builder.ConstantR0(0.5f)); + + ComputeAndCompareR0(&builder, 5.775f, {}, error_spec_); +} + +TEST_F(ScalarComputationsTest, MulTwoScalarsS32) { + std::vector data = {0, + 1, + -1, + 1234, + 0x1a243514, + std::numeric_limits::max(), + std::numeric_limits::min()}; + + for (int32 x : data) { + for (int32 y : data) { + ComputationBuilder builder(client_, TestName()); + builder.Mul(builder.ConstantR0(x), builder.ConstantR0(y)); + + // Signed integer overflow is undefined behavior in C++. Convert the input + // integers to unsigned, perform the multiplication unsigned, and convert + // back. + int32 expected = static_cast(x) * static_cast(y); + + ComputeAndCompareR0(&builder, expected, {}); + } + } +} + +TEST_F(ScalarComputationsTest, MulTwoScalarsU32) { + std::vector data = {0, 1, 0xDEADBEEF, 1234, + 0x1a243514, 0xFFFFFFFF, 0x80808080}; + + for (uint32 x : data) { + for (uint32 y : data) { + ComputationBuilder builder(client_, TestName()); + builder.Mul(builder.ConstantR0(x), builder.ConstantR0(y)); + + uint32 expected = x * y; + ComputeAndCompareR0(&builder, expected, {}); + } + } +} + +TEST_F(ScalarComputationsTest, MulThreeScalarsS32) { + ComputationBuilder builder(client_, TestName()); + builder.Mul( + builder.Mul(builder.ConstantR0(2), builder.ConstantR0(5)), + builder.ConstantR0(1)); + + ComputeAndCompareR0(&builder, 10, {}); +} + +TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) { + ComputationBuilder builder(client_, TestName()); + std::unique_ptr a_literal = LiteralUtil::CreateR0(2.1f); + std::unique_ptr b_literal = LiteralUtil::CreateR0(5.5f); + std::unique_ptr c_literal = LiteralUtil::CreateR0(0.5f); + + std::unique_ptr a_data = + client_->TransferToServer(*a_literal).ConsumeValueOrDie(); + std::unique_ptr b_data = + client_->TransferToServer(*b_literal).ConsumeValueOrDie(); + std::unique_ptr c_data = + client_->TransferToServer(*c_literal).ConsumeValueOrDie(); + + ComputationDataHandle a = builder.Parameter(0, a_literal->shape(), "a"); + ComputationDataHandle b = builder.Parameter(1, b_literal->shape(), "b"); + ComputationDataHandle c = builder.Parameter(2, c_literal->shape(), "c"); + builder.Mul(builder.Mul(a, b), c); + + ComputeAndCompareR0(&builder, 5.775f, + {a_data.get(), b_data.get(), c_data.get()}, + error_spec_); +} + +TEST_F(ScalarComputationsTest, DivideTwoScalarsF32) { + ComputationBuilder builder(client_, TestName()); + builder.Div(builder.ConstantR0(5.0f), builder.ConstantR0(2.5f)); + + ComputeAndCompareR0(&builder, 2.0f, {}, error_spec_); +} + +XLA_TEST_F(ScalarComputationsTest, RemTwoScalarsF32) { + ComputationBuilder builder(client_, TestName()); + builder.Rem(builder.ConstantR0(2.5f), builder.ConstantR0(5.0f)); + + ComputeAndCompareR0(&builder, 2.5f, {}, error_spec_); +} + +XLA_TEST_F(ScalarComputationsTest, DivideTwoScalarsS32) { + ComputationBuilder builder(client_, TestName()); + builder.Div(builder.ConstantR0(-5), builder.ConstantR0(2)); + + ComputeAndCompareR0(&builder, -2, {}); +} + +TEST_F(ScalarComputationsTest, RemainderTwoScalarsNegativeResultS32) { + ComputationBuilder builder(client_, TestName()); + builder.Rem(builder.ConstantR0(-5), builder.ConstantR0(2)); + + ComputeAndCompareR0(&builder, -1, {}); +} + +TEST_F(ScalarComputationsTest, RemainderTwoScalarsIntMinS32) { + ComputationBuilder builder(client_, TestName()); + builder.Rem(builder.ConstantR0(INT_MIN), + 
builder.ConstantR0(7919)); + + ComputeAndCompareR0(&builder, -1309, {}); +} + +TEST_F(ScalarComputationsTest, RemainderTwoScalarsIntMinVsIntMaxS32) { + ComputationBuilder builder(client_, TestName()); + builder.Rem(builder.ConstantR0(INT_MIN), + builder.ConstantR0(INT_MAX)); + + ComputeAndCompareR0(&builder, -1, {}); +} + +TEST_F(ScalarComputationsTest, RemainderTwoScalarsPositiveResultS32) { + ComputationBuilder builder(client_, TestName()); + auto x = builder.Parameter(0, ShapeUtil::MakeShape(S32, {}), "x"); + builder.Rem(x, builder.ConstantR0(80000)); + + std::unique_ptr literal = LiteralUtil::CreateR0(87919); + TF_ASSIGN_OR_ASSERT_OK(auto input_data, client_->TransferToServer(*literal)); + ComputeAndCompareR0(&builder, 7919, {input_data.get()}); +} + +XLA_TEST_F(ScalarComputationsTest, DivideTwoScalarsU32) { + ComputationBuilder builder(client_, TestName()); + // This verifies 0xFFFFFFFE / 2 = 0x7FFFFFFF. If XLA incorrectly treated U32 + // as S32, it would output -2 / 2 = -1 (0xFFFFFFFF). + builder.Div(builder.ConstantR0(0xFFFFFFFE), + builder.ConstantR0(2)); + + ComputeAndCompareR0(&builder, 0x7FFFFFFF, {}); +} + +TEST_F(ScalarComputationsTest, LogicalAnd) { + for (bool x : {false, true}) { + for (bool y : {false, true}) { + ComputationBuilder builder(client_, TestName()); + builder.LogicalAnd(builder.ConstantR0(x), + builder.ConstantR0(y)); + + ComputeAndCompareR0(&builder, x && y, {}); + } + } +} + +TEST_F(ScalarComputationsTest, LogicalOr) { + for (bool x : {false, true}) { + for (bool y : {false, true}) { + ComputationBuilder builder(client_, TestName()); + builder.LogicalOr(builder.ConstantR0(x), + builder.ConstantR0(y)); + + ComputeAndCompareR0(&builder, x || y, {}); + } + } +} + +TEST_F(ScalarComputationsTest, LogicalNot) { + for (bool x : {false, true}) { + ComputationBuilder builder(client_, TestName()); + builder.LogicalNot(builder.ConstantR0(x)); + + ComputeAndCompareR0(&builder, !x, {}); + } +} + +TEST_F(ScalarComputationsTest, SelectScalarTrue) { + ComputationBuilder builder(client_, TestName()); + builder.Select(builder.ConstantR0(true), // The predicate. + builder.ConstantR0(123.0f), // The value on true. + builder.ConstantR0(42.0f)); // The value on false. + + ComputeAndCompareR0(&builder, 123.0f, {}, error_spec_); +} + +TEST_F(ScalarComputationsTest, SelectScalarFalse) { + ComputationBuilder builder(client_, TestName()); + builder.Select(builder.ConstantR0(false), // The predicate. + builder.ConstantR0(123.0f), // The value on true. + builder.ConstantR0(42.0f)); // The value on false. + + ComputeAndCompareR0(&builder, 42.0f, {}, error_spec_); +} + +// This test is an explicit version of what is happening in the following +// templatized comparison tests. +TEST_F(ScalarComputationsTest, CompareGtScalar) { + ComputationBuilder builder(client_, TestName()); + builder.Gt(builder.ConstantR0(2.0f), builder.ConstantR0(1.0f)); + + ComputeAndCompareR0(&builder, true, {}); +} + +// S32 comparisons. 
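+// The one-line comparison tests below (for S32, U32, and F32 alike) all go
+// through the templatized TestCompare helper defined earlier in this file.
+// A minimal sketch of the shape its call sites imply, written out here only
+// for illustration (the real helper may differ in its details):
+//
+//   template <typename NativeT>
+//   void TestCompare(NativeT lhs, NativeT rhs, bool expected,
+//                    ComputationDataHandle (ComputationBuilder::*op)(
+//                        const ComputationDataHandle&,
+//                        const ComputationDataHandle&,
+//                        tensorflow::gtl::ArraySlice<int64>)) {
+//     ComputationBuilder builder(client_, TestName());
+//     (builder.*op)(builder.ConstantR0<NativeT>(lhs),
+//                   builder.ConstantR0<NativeT>(rhs),
+//                   /*broadcast_dimensions=*/{});
+//     ComputeAndCompareR0<bool>(&builder, expected, {});
+//   }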
+TEST_F(ScalarComputationsTest, CompareEqS32Greater) { + TestCompare(2, 1, false, &ComputationBuilder::Eq); +} +TEST_F(ScalarComputationsTest, CompareEqS32Equal) { + TestCompare(3, 3, true, &ComputationBuilder::Eq); +} + +TEST_F(ScalarComputationsTest, CompareNeS32) { + TestCompare(2, 1, true, &ComputationBuilder::Ne); +} + +TEST_F(ScalarComputationsTest, CompareGeS32) { + TestCompare(2, 1, true, &ComputationBuilder::Ge); +} + +TEST_F(ScalarComputationsTest, CompareGtS32) { + TestCompare(1, 5, false, &ComputationBuilder::Gt); +} + +TEST_F(ScalarComputationsTest, CompareLeS32) { + TestCompare(2, 1, false, &ComputationBuilder::Le); +} + +TEST_F(ScalarComputationsTest, CompareLtS32) { + TestCompare(9, 7, false, &ComputationBuilder::Lt); + TestCompare(std::numeric_limits::min(), + std::numeric_limits::max(), true, + &ComputationBuilder::Lt); +} + +// U32 comparisons. +TEST_F(ScalarComputationsTest, CompareEqU32False) { + TestCompare(2, 1, false, &ComputationBuilder::Eq); +} + +TEST_F(ScalarComputationsTest, CompareNeU32) { + TestCompare(2, 1, true, &ComputationBuilder::Ne); +} + +TEST_F(ScalarComputationsTest, CompareGeU32Greater) { + TestCompare(2, 1, true, &ComputationBuilder::Ge); +} + +TEST_F(ScalarComputationsTest, CompareGeU32Equal) { + TestCompare(3, 3, true, &ComputationBuilder::Ge); +} + +TEST_F(ScalarComputationsTest, CompareGtU32) { + TestCompare(1, 5, false, &ComputationBuilder::Gt); + TestCompare(5, 5, false, &ComputationBuilder::Gt); + TestCompare(5, 1, true, &ComputationBuilder::Gt); +} + +TEST_F(ScalarComputationsTest, CompareLeU32) { + TestCompare(2, 1, false, &ComputationBuilder::Le); +} + +TEST_F(ScalarComputationsTest, CompareLtU32) { + TestCompare(9, 7, false, &ComputationBuilder::Lt); + TestCompare(0, std::numeric_limits::max(), true, + &ComputationBuilder::Lt); +} + +// F32 comparisons. +TEST_F(ScalarComputationsTest, CompareEqF32False) { + TestCompare(2.0, 1.3, false, &ComputationBuilder::Eq); +} + +TEST_F(ScalarComputationsTest, CompareNeF32) { + TestCompare(2.0, 1.3, true, &ComputationBuilder::Ne); +} + +TEST_F(ScalarComputationsTest, CompareGeF32Greater) { + TestCompare(2.0, 1.9, true, &ComputationBuilder::Ge); +} +TEST_F(ScalarComputationsTest, CompareGeF32Equal) { + TestCompare(3.5, 3.5, true, &ComputationBuilder::Ge); +} + +TEST_F(ScalarComputationsTest, CompareGtF32) { + TestCompare(1.0, 5.2, false, &ComputationBuilder::Gt); +} + +TEST_F(ScalarComputationsTest, CompareLeF32) { + TestCompare(2.0, 1.2, false, &ComputationBuilder::Le); +} + +TEST_F(ScalarComputationsTest, CompareLtF32) { + TestCompare(9.0, 7.2, false, &ComputationBuilder::Lt); +} + +// F32 comparisons with exceptional values. The test names encode the +// left/right operands at the end, and use Minf and Mzero for -inf and -0.0. +TEST_F(ScalarComputationsTest, CompareLtF32MinfMzero) { + TestCompare(-INFINITY, -0.0, true, &ComputationBuilder::Lt); +} +TEST_F(ScalarComputationsTest, CompareLtF32MzeroZero) { + // Comparisons of 0.0 to -0.0 consider them equal in IEEE 754. + TestCompare(-0.0, 0.0, false, &ComputationBuilder::Lt); +} +TEST_F(ScalarComputationsTest, CompareLtF32ZeroInf) { + TestCompare(0.0, INFINITY, true, &ComputationBuilder::Lt); +} + +TEST_F(ScalarComputationsTest, CompareGeF32MinfMzero) { + TestCompare(-INFINITY, -0.0, false, &ComputationBuilder::Ge); +} +TEST_F(ScalarComputationsTest, CompareGeF32MzeroZero) { + // Comparisons of 0.0 to -0.0 consider them equal in IEEE 754. 
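+  // Hence -0.0 >= 0.0 is expected to be true here, mirroring the
+  // CompareLtF32MzeroZero case above where -0.0 < 0.0 is false.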
+  TestCompare<float>(-0.0, 0.0, true, &ComputationBuilder::Ge);
+}
+TEST_F(ScalarComputationsTest, CompareGeF32ZeroInf) {
+  TestCompare<float>(0.0, INFINITY, false, &ComputationBuilder::Ge);
+}
+
+TEST_F(ScalarComputationsTest, ExpScalar) {
+  ComputationBuilder builder(client_, TestName());
+  builder.Exp(builder.ConstantR0<float>(2.0f));
+
+  ComputeAndCompareR0<float>(&builder, 7.3890562, {}, error_spec_);
+}
+
+TEST_F(ScalarComputationsTest, LogScalar) {
+  ComputationBuilder builder(client_, "log");
+  builder.Log(builder.ConstantR0<float>(2.0f));
+
+  ComputeAndCompareR0<float>(&builder, 0.6931471, {}, error_spec_);
+}
+
+TEST_F(ScalarComputationsTest, TanhScalar) {
+  ComputationBuilder builder(client_, TestName());
+  builder.Tanh(builder.ConstantR0<float>(2.0f));
+
+  ComputeAndCompareR0<float>(&builder, 0.96402758, {}, error_spec_);
+}
+
+XLA_TEST_F(ScalarComputationsTest, TanhDoubleScalar) {
+  ComputationBuilder builder(client_, TestName());
+  builder.Tanh(builder.ConstantR0<double>(2.0));
+
+  ComputeAndCompareR0<double>(&builder, 0.96402758, {}, error_spec_);
+}
+
+TEST_F(ScalarComputationsTest, PowScalar) {
+  ComputationBuilder builder(client_, TestName());
+  builder.Pow(builder.ConstantR0<float>(2.0f), builder.ConstantR0<float>(3.0f));
+
+  ComputeAndCompareR0<float>(&builder, 8.0, {}, error_spec_);
+}
+
+TEST_F(ScalarComputationsTest, ClampScalarHigh) {
+  ComputationBuilder builder(client_, TestName());
+  builder.Clamp(builder.ConstantR0<float>(2.0f),   // The lower bound.
+                builder.ConstantR0<float>(5.0f),   // The operand to be clamped.
+                builder.ConstantR0<float>(3.0f));  // The upper bound.
+
+  ComputeAndCompareR0<float>(&builder, 3.0, {}, error_spec_);
+}
+
+TEST_F(ScalarComputationsTest, ClampScalarMiddle) {
+  ComputationBuilder builder(client_, TestName());
+  builder.Clamp(builder.ConstantR0<float>(2.0f),   // The lower bound.
+                builder.ConstantR0<float>(2.5f),   // The operand to be clamped.
+                builder.ConstantR0<float>(3.0f));  // The upper bound.
+
+  ComputeAndCompareR0<float>(&builder, 2.5, {}, error_spec_);
+}
+
+TEST_F(ScalarComputationsTest, ClampScalarLow) {
+  ComputationBuilder builder(client_, TestName());
+  builder.Clamp(builder.ConstantR0<float>(2.0f),    // The lower bound.
+                builder.ConstantR0<float>(-5.0f),   // The operand to be clamped.
+                builder.ConstantR0<float>(3.0f));   // The upper bound.
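+  // -5.0f sits below the lower bound of 2.0f, so the clamp is expected to
+  // return the lower bound.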
+ + ComputeAndCompareR0(&builder, 2.0, {}, error_spec_); +} + +TEST_F(ScalarComputationsTest, MinS32Above) { + TestMinMax(10, 3, 3, &ComputationBuilder::Min); +} + +TEST_F(ScalarComputationsTest, MinS32Below) { + TestMinMax(-100, 3, -100, &ComputationBuilder::Min); +} + +TEST_F(ScalarComputationsTest, MaxS32Above) { + TestMinMax(10, 3, 10, &ComputationBuilder::Max); +} + +TEST_F(ScalarComputationsTest, MaxS32Below) { + TestMinMax(-100, 3, 3, &ComputationBuilder::Max); +} + +TEST_F(ScalarComputationsTest, MinU32Above) { + const uint32 large = std::numeric_limits::max(); + TestMinMax(large, 3, 3, &ComputationBuilder::Min); +} + +TEST_F(ScalarComputationsTest, MinU32Below) { + TestMinMax(0, 5, 0, &ComputationBuilder::Min); +} + +TEST_F(ScalarComputationsTest, MaxU32Above) { + const uint32 large = std::numeric_limits::max(); + TestMinMax(large, 3, large, &ComputationBuilder::Max); +} + +TEST_F(ScalarComputationsTest, MaxU32Below) { + TestMinMax(0, 5, 5, &ComputationBuilder::Max); +} + +TEST_F(ScalarComputationsTest, MinF32Above) { + TestMinMax(10.1f, 3.1f, 3.1f, &ComputationBuilder::Min); +} + +TEST_F(ScalarComputationsTest, MinF32Below) { + TestMinMax(-100.1f, 3.1f, -100.1f, &ComputationBuilder::Min); +} + +TEST_F(ScalarComputationsTest, MaxF32Above) { + TestMinMax(10.1f, 3.1f, 10.1f, &ComputationBuilder::Max); +} + +TEST_F(ScalarComputationsTest, MaxF32Below) { + TestMinMax(-100.1f, 3.1f, 3.1f, &ComputationBuilder::Max); +} + +TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionF32) { + // Compute the expression (1 * (3 - 1) * (7 + 0) - 4) / 20. + ComputationBuilder b(client_, TestName()); + b.Div( + b.Sub(b.Mul(b.ConstantR0(1), + b.Mul(b.Sub(b.ConstantR0(3), b.ConstantR0(1)), + b.Add(b.ConstantR0(7), b.ConstantR0(0)))), + b.ConstantR0(4)), + b.ConstantR0(20)); + + ComputeAndCompareR0(&b, 0.5, {}, error_spec_); +} + +TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionS32) { + // Compute the expression 1 * (3 - 1) * (7 + 0) - 4. + ComputationBuilder b(client_, TestName()); + b.Sub(b.Mul(b.ConstantR0(1), + b.Mul(b.Sub(b.ConstantR0(3), b.ConstantR0(1)), + b.Add(b.ConstantR0(7), b.ConstantR0(0)))), + b.ConstantR0(4)); + + ComputeAndCompareR0(&b, 10, {}); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendLlvmBackendFlags(&flag_list); + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc new file mode 100644 index 0000000000..fb1effc8c4 --- /dev/null +++ b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc @@ -0,0 +1,395 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Tests the select-and-scatter XLA operation. + +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/padding.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/reference_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class SelectAndScatterTest : public ClientLibraryTestBase { + public: + SelectAndScatterTest() : builder_(client_, TestName()) { + // Create S32 GE and ADD computations for select and scatter respectively. + ge_s32_ = CreateScalarGeComputation(S32, &builder_); + add_s32_ = CreateScalarAddComputation(S32, &builder_); + ge_f32_ = CreateScalarGeComputation(F32, &builder_); + add_f32_ = CreateScalarAddComputation(F32, &builder_); + max_f32_ = CreateScalarMaxComputation(F32, &builder_); + min_f32_ = CreateScalarMinComputation(F32, &builder_); + } + + ComputationBuilder builder_; + Computation ge_s32_; + Computation add_s32_; + Computation ge_f32_; + Computation add_f32_; + Computation max_f32_; + Computation min_f32_; +}; + +// Test for F32 1D array, with a zero-element input. +XLA_TEST_F(SelectAndScatterTest, R1S0F32) { + const auto operand = builder_.ConstantR1({}); + const auto source = builder_.ConstantR1({}); + builder_.SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{3}, + /*window_strides=*/{3}, Padding::kValid, source, + builder_.ConstantR0(0.0f), add_f32_); + ComputeAndCompareR1(&builder_, {}, {}, ErrorSpec(1e-7)); +} + +// Test for F32 1D array, when windows do not overlap. +XLA_TEST_F(SelectAndScatterTest, R1F32) { + const auto operand = + builder_.ConstantR1({1.f, 9.f, 3.f, 7.f, 5.f, 6.f}); + const auto source = builder_.ConstantR1({34.f, 42.f}); + const std::vector expected = {0.f, 34.f, 0.f, 42.f, 0.f, 0.f}; + builder_.SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{3}, + /*window_strides=*/{3}, Padding::kValid, source, + builder_.ConstantR0(0.0f), add_f32_); + ComputeAndCompareR1(&builder_, expected, {}, ErrorSpec(1e-7)); +} + +// Test for S32 1D array, when windows do not overlap and the init value is 1. 
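+// Worked through by hand under the usual select-and-scatter semantics (each
+// source value is added at the position the ge computation selects within its
+// window, and every output element starts at the init value):
+//   window [-1, 0, 6]  selects the 6  at index 2: output[2] = 1 + (-10) = -9
+//   window [4, -4, 10] selects the 10 at index 5: output[5] = 1 + 20    = 21
+//   every other position keeps the init value, giving {1, 1, -9, 1, 1, 21}.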
+XLA_TEST_F(SelectAndScatterTest, R1S32) { + const auto operand = builder_.ConstantR1({-1, 0, 6, 4, -4, 10}); + const auto source = builder_.ConstantR1({-10, 20}); + const std::vector expected = {1, 1, -9, 1, 1, 21}; + builder_.SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{3}, + /*window_strides=*/{3}, Padding::kValid, source, + builder_.ConstantR0(1), add_s32_); + ComputeAndCompareR1(&builder_, expected, {}); +} + +// Test for S32 1D array, when windows overlap with each other. +XLA_TEST_F(SelectAndScatterTest, R1S32OverlappingWindow) { + const auto operand = builder_.ConstantR1({1, 9, 3, 7, 5, 6}); + const auto source = builder_.ConstantR1({34, 42, 53, 19}); + const std::vector expected = {0, 76, 0, 72, 0, 0}; + builder_.SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{3}, + /*window_strides=*/{1}, Padding::kValid, source, + builder_.ConstantR0(0), add_s32_); + ComputeAndCompareR1(&builder_, expected, {}); +} + +// Test for S32 2D array, when windows do not overlap. +XLA_TEST_F(SelectAndScatterTest, R2S32) { + const auto operand = + builder_.ConstantR2({{7, 2, 5, 3, 10, 2}, {3, 8, 9, 3, 4, 2}}); + const auto source = builder_.ConstantR2({{2, 6}}); + Array2D expected({{0, 0, 0, 0, 6, 0}, {0, 0, 2, 0, 0, 0}}); + builder_.SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 3}, + /*window_strides=*/{2, 3}, Padding::kValid, source, + builder_.ConstantR0(0), add_s32_); + ComputeAndCompareR2(&builder_, expected, {}); +} + +// Similar to SelectAndScatterTest.R2S32 but the input is transposed. +XLA_TEST_F(SelectAndScatterTest, ReshapeR2S32) { + const auto operand = builder_.ConstantR2( + {{7, 3}, {2, 8}, {5, 9}, {3, 3}, {10, 4}, {2, 2}}); + const auto reshape = + builder_.Reshape(operand, /*dimensions=*/{1, 0}, /*new_sizes=*/{2, 6}); + const auto source = builder_.ConstantR2({{2, 6}}); + Array2D expected({{0, 0, 0, 0, 6, 0}, {0, 0, 2, 0, 0, 0}}); + builder_.SelectAndScatter(reshape, ge_s32_, /*window_dimensions=*/{2, 3}, + /*window_strides=*/{2, 3}, Padding::kValid, source, + builder_.ConstantR0(0), add_s32_); + ComputeAndCompareR2(&builder_, expected, {}); +} + +// Test for S32 2D array, when windows overlap with each other. +XLA_TEST_F(SelectAndScatterTest, R2S32OverlappingWindow) { + const auto operand = + builder_.ConstantR2({{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}}); + const auto source = builder_.ConstantR2({{2, 6, 4}}); + Array2D expected({{0, 0, 0, 0, 0}, {0, 0, 12, 0, 0}}); + builder_.SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 3}, + /*window_strides=*/{1, 1}, Padding::kValid, source, + builder_.ConstantR0(0), add_s32_); + ComputeAndCompareR2(&builder_, expected, {}); +} + +// Test for S32 2D array, when the padding is Padding::kSAME. +XLA_TEST_F(SelectAndScatterTest, R2S32SamePadding) { + const auto operand = + builder_.ConstantR2({{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}}); + const auto source = builder_.ConstantR2({{2, 6, 4}}); + Array2D expected({{0, 0, 0, 0, 4}, {0, 2, 6, 0, 0}}); + builder_.SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 2}, + /*window_strides=*/{2, 2}, Padding::kSame, source, + builder_.ConstantR0(0), add_s32_); + ComputeAndCompareR2(&builder_, expected, {}); +} + +// Test for S32 2D array, when the padding is Padding::kSAME and windows overlap +// with each other. 
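+// With a 2x2 window and stride 1, the operand's largest element, the 9 at
+// (1, 2), is selected by the four windows anchored at (0, 1), (0, 2), (1, 1),
+// and (1, 2), so it accumulates 6 + 4 + 5 + 8 = 23 from the corresponding
+// source entries; that is the 23 in the expected array below.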
+XLA_TEST_F(SelectAndScatterTest, R2S32SamePaddingOverlappingWindow) { + const auto operand = + builder_.ConstantR2({{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}}); + const auto source = + builder_.ConstantR2({{2, 6, 4, 7, 1}, {3, 5, 8, 9, 10}}); + Array2D expected({{0, 0, 0, 0, 8}, {0, 5, 23, 0, 19}}); + builder_.SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 2}, + /*window_strides=*/{1, 1}, Padding::kSame, source, + builder_.ConstantR0(0), add_s32_); + ComputeAndCompareR2(&builder_, expected, {}); +} + +XLA_TEST_F(SelectAndScatterTest, R2F32OverlappingR2Source) { + const auto operand = builder_.ConstantR2( + {{1.5f, 2.5f, 1.5f}, {3.5f, 1.5f, 3.5f}, {4.5f, 2.5f, 4.5f}}); + const auto source = builder_.ConstantR2({{1.0f, 2.0f}, {3.0f, 4.0f}}); + Array2D expected( + {{0.0f, 0.0f, 0.0f}, {1.0f, 0.0f, 2.0f}, {3.0f, 0.0f, 4.0f}}); + builder_.SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{2, 2}, + /*window_strides=*/{1, 1}, Padding::kValid, source, + builder_.ConstantR0(0.0f), add_f32_); + ComputeAndCompareR2(&builder_, expected, {}, ErrorSpec(1e-7)); +} + +TEST_F(SelectAndScatterTest, R4F32Valid) { + Array2D pzo = {{7.0f, 2.0f, 5.0f, 3.0f, 10.0f, 2.0f}, + {3.0f, 8.0f, 9.0f, 3.0f, 4.00f, 2.0f}, + {1.0f, 5.0f, 7.0f, 5.0f, 6.00f, 1.0f}, + {0.0f, 6.0f, 2.0f, 7.0f, 2.00f, 8.0f}}; + Array2D pzs = {{2.0f, 6.0f}, {3.0f, 1.0f}}; + Array2D pze = {{0.0f, 0.0f, 0.0f, 0.0f, 6.0f, 0.0f}, + {0.0f, 0.0f, 2.0f, 0.0f, 0.0f, 0.0f}, + {0.0f, 0.0f, 3.0f, 0.0f, 0.0f, 0.0f}, + {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f}}; + Array4D o(4, 6, 15, 220); + o.FillWithPZ(pzo); + auto operand = builder_.ConstantR4FromArray4D(o); + Array4D e(4, 6, 15, 220); + e.FillWithPZ(pze); + Array4D s(2, 2, 15, 220); + s.FillWithPZ(pzs); + auto source = builder_.ConstantR4FromArray4D(s); + s.FillWithPZ(pzs); + builder_.SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 3, 1, 1}, + Padding::kValid, source, + builder_.ConstantR0(0.0f), add_f32_); + ComputeAndCompareR4(&builder_, e, {}, ErrorSpec(1e-7)); +} + +TEST_F(SelectAndScatterTest, R4F32Overlap) { + Array2D pzo = {{7.0f, 2.0f, 5.0f, 3.0f, 8.0f}, + {3.0f, 8.0f, 9.0f, 3.0f, 4.0f}, + {1.0f, 5.0f, 7.0f, 5.0f, 6.0f}, + {0.0f, 6.0f, 2.0f, 10.0f, 2.0f}}; + Array2D pzs = {{2.0f, 6.0f}, {3.0f, 1.0f}}; + Array2D pze = {{0.0f, 0.0f, 0.0f, 0.0f, 0.0f}, + {0.0f, 0.0f, 8.0f, 0.0f, 0.0f}, + {0.0f, 0.0f, 3.0f, 0.0f, 0.0f}, + {0.0f, 0.0f, 0.0f, 1.0f, 0.0f}}; + Array4D o(4, 5, 17, 128); + o.FillWithPZ(pzo); + auto operand = builder_.ConstantR4FromArray4D(o); + Array4D e(4, 5, 17, 128); + e.FillWithPZ(pze); + Array4D s(2, 2, 17, 128); + s.FillWithPZ(pzs); + auto source = builder_.ConstantR4FromArray4D(s); + s.FillWithPZ(pzs); + builder_.SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 2, 1, 1}, + Padding::kValid, source, + builder_.ConstantR0(0.0f), add_f32_); + ComputeAndCompareR4(&builder_, e, {}, ErrorSpec(1e-7)); +} + +TEST_F(SelectAndScatterTest, R4F32OverlapSmall) { + Array2D pzo = {{7.0f, 2.0f, 5.0f, 3.0f, 8.0f}, + {3.0f, 8.0f, 9.0f, 3.0f, 4.0f}, + {1.0f, 5.0f, 7.0f, 5.0f, 6.0f}, + {0.0f, 6.0f, 2.0f, 10.0f, 2.0f}}; + Array2D pzs = {{2.0f, 6.0f}, {3.0f, 1.0f}}; + Array2D pze = {{0.0f, 0.0f, 0.0f, 0.0f, 0.0f}, + {0.0f, 0.0f, 8.0f, 0.0f, 0.0f}, + {0.0f, 0.0f, 3.0f, 0.0f, 0.0f}, + {0.0f, 0.0f, 0.0f, 1.0f, 0.0f}}; + Array4D o(4, 5, 1, 1); + o.FillWithPZ(pzo); + auto operand = builder_.ConstantR4FromArray4D(o); + Array4D e(4, 5, 1, 1); + e.FillWithPZ(pze); + Array4D s(2, 2, 1, 1); + s.FillWithPZ(pzs); + auto source = builder_.ConstantR4FromArray4D(s); + s.FillWithPZ(pzs); + 
builder_.SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 2, 1, 1}, + Padding::kValid, source, + builder_.ConstantR0(0.0f), add_f32_); + ComputeAndCompareR4(&builder_, e, {}, ErrorSpec(1e-7)); +} + +TEST_F(SelectAndScatterTest, R4F32RefValidFixedSmall) { + // This test is testing the Reference Util + Array2D pzo = {{7.0f, 2.0f, 5.0f, 3.0f, 10.0f, 2.0f}, + {3.0f, 8.0f, 9.0f, 3.0f, 4.00f, 2.0f}, + {1.0f, 5.0f, 7.0f, 5.0f, 6.00f, 1.0f}, + {0.0f, 6.0f, 2.0f, 7.0f, 2.00f, 8.0f}}; + Array2D pzs = {{2.0f, 6.0f}, {3.0f, 1.0f}}; + Array4D o(4, 6, 4, 4); + o.FillWithPZ(pzo); + auto operand = builder_.ConstantR4FromArray4D(o); + Array4D s(2, 2, 4, 4); + s.FillWithPZ(pzs); + + auto source = builder_.ConstantR4FromArray4D(s); + s.FillWithPZ(pzs); + builder_.SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 3, 1, 1}, + Padding::kValid, source, + builder_.ConstantR0(0.0f), add_f32_); + auto e = ReferenceUtil::SelectAndScatter4DGePlus(o, s, 0.0f, {2, 3, 1, 1}, + {2, 3, 1, 1}, false); + ComputeAndCompareR4(&builder_, *e, {}, ErrorSpec(1e-7)); +} + +TEST_F(SelectAndScatterTest, R4F32RefSameRandom) { + Array4D o(7, 7, 8, 256); + o.FillRandom(1.5f); + auto operand = builder_.ConstantR4FromArray4D(o); + + Array4D s(4, 4, 8, 256); + s.FillRandom(12.0f); + auto source = builder_.ConstantR4FromArray4D(s); + + builder_.SelectAndScatter(operand, ge_f32_, {2, 2, 1, 1}, {2, 2, 1, 1}, + Padding::kSame, source, + builder_.ConstantR0(0.0f), add_f32_); + auto e = ReferenceUtil::SelectAndScatter4DGePlus(o, s, 0.0f, {2, 2, 1, 1}, + {2, 2, 1, 1}, true); + + ComputeAndCompareR4(&builder_, *e, {}, ErrorSpec(1e-7)); +} + +TEST_F(SelectAndScatterTest, R4F32RefSameRandomFullyPadded) { + Array4D o(1, 1, 5, 5); + o.FillRandom(1.5f); + auto operand = builder_.ConstantR4FromArray4D(o); + + Array4D s(1, 1, 5, 5); + s.FillRandom(12.0f); + auto source = builder_.ConstantR4FromArray4D(s); + + builder_.SelectAndScatter(operand, ge_f32_, {3, 3, 1, 1}, {3, 3, 1, 1}, + Padding::kSame, source, + builder_.ConstantR0(0.0f), add_f32_); + auto e = ReferenceUtil::SelectAndScatter4DGePlus(o, s, 0.0f, {3, 3, 1, 1}, + {3, 3, 1, 1}, true); + + ComputeAndCompareR4(&builder_, *e, {}, ErrorSpec(1e-7)); +} + +TEST_F(SelectAndScatterTest, R4F32RefValidRandom) { + Array4D o(9, 9, 16, 128); + o.FillRandom(1.5f); + auto operand = builder_.ConstantR4FromArray4D(o); + + Array4D s(3, 3, 16, 128); + s.FillRandom(12.0f); + auto source = builder_.ConstantR4FromArray4D(s); + + builder_.SelectAndScatter(operand, ge_f32_, {3, 3, 1, 1}, {3, 3, 1, 1}, + Padding::kValid, source, + builder_.ConstantR0(0.0f), add_f32_); + + auto e = ReferenceUtil::SelectAndScatter4DGePlus(o, s, 0.0f, {3, 3, 1, 1}, + {3, 3, 1, 1}, false); + + ComputeAndCompareR4(&builder_, *e, {}, ErrorSpec(1e-7)); +} + +TEST_F(SelectAndScatterTest, R4F32RefValidRandomSmall) { + Array4D o(3, 3, 4, 4); + o.FillRandom(1.5f); + auto operand = builder_.ConstantR4FromArray4D(o); + + Array4D s(1, 1, 4, 4); + s.FillRandom(12.0f); + auto source = builder_.ConstantR4FromArray4D(s); + + builder_.SelectAndScatter(operand, ge_f32_, {3, 3, 1, 1}, {3, 3, 1, 1}, + Padding::kValid, source, + builder_.ConstantR0(0.0f), add_f32_); + + auto e = ReferenceUtil::SelectAndScatter4DGePlus(o, s, 0.0f, {3, 3, 1, 1}, + {3, 3, 1, 1}, false); + + ComputeAndCompareR4(&builder_, *e, {}, ErrorSpec(1e-7)); +} + +XLA_TEST_F(SelectAndScatterTest, R1F32OverlappingWindowMaxScatter) { + const auto operand = builder_.ConstantR1({1, 2, 3, 100, 3, 2, 1}); + const auto source = builder_.ConstantR1({34, 42, 53, 19}); + const std::vector 
expected = {0, 0, 0, 53, 0, 0, 0}; + builder_.SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{4}, + /*window_strides=*/{1}, Padding::kValid, source, + builder_.ConstantR0(0), max_f32_); + ComputeAndCompareR1(&builder_, expected, {}, ErrorSpec(1e-7)); +} + +XLA_TEST_F(SelectAndScatterTest, R1F32OverlappingWindowMinScatter) { + const auto operand = builder_.ConstantR1({1, 2, 3, 100, 3, 2, 1}); + const auto source = builder_.ConstantR1({34, 42, 53, 19}); + const float max_float = std::numeric_limits::max(); + const std::vector expected = {max_float, max_float, max_float, 19, + max_float, max_float, max_float}; + builder_.SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{4}, + /*window_strides=*/{1}, Padding::kValid, source, + builder_.ConstantR0(max_float), min_f32_); + ComputeAndCompareR1(&builder_, expected, {}, ErrorSpec(1e-7)); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/select_test.cc b/tensorflow/compiler/xla/tests/select_test.cc new file mode 100644 index 0000000000..5ec9ac95fa --- /dev/null +++ b/tensorflow/compiler/xla/tests/select_test.cc @@ -0,0 +1,276 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class SelectTest : public ClientLibraryTestBase { + public: + ErrorSpec error_spec_{0.0001}; +}; + +TEST_F(SelectTest, SelectScalarF32True) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(true); + auto on_true = builder.ConstantR0(123.0f); + auto on_false = builder.ConstantR0(42.0f); + auto result = builder.Select(pred, on_true, on_false); + + ComputeAndCompareR0(&builder, 123.0f, {}, error_spec_); +} + +TEST_F(SelectTest, SelectScalarS32True) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(true); + auto on_true = builder.ConstantR0(-42); + auto on_false = builder.ConstantR0(42); + auto result = builder.Select(pred, on_true, on_false); + + ComputeAndCompareR0(&builder, -42, {}); +} + +TEST_F(SelectTest, SelectScalarF32False) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(false); + auto on_true = builder.ConstantR0(123.0f); + auto on_false = builder.ConstantR0(42.0f); + auto result = builder.Select(pred, on_true, on_false); + + ComputeAndCompareR0(&builder, 42.0f, {}, error_spec_); +} + +XLA_TEST_F(SelectTest, SelectR1S0F32WithConstantR1S0PRED) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR1({}); + auto on_true = builder.ConstantR1({}); + auto on_false = builder.ConstantR1({}); + auto select = builder.Select(pred, on_true, on_false); + + ComputeAndCompareR1(&builder, {}, {}, error_spec_); +} + +TEST_F(SelectTest, SelectR1F32WithConstantR1PRED) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR1({false, true, false, true, false}); + auto on_true = builder.ConstantR1({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); + auto on_false = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, -6.0f}); + auto select = builder.Select(pred, on_true, on_false); + + ComputeAndCompareR1(&builder, {10.0f, 25.5f, 1.0f, -10.0f, -6.0f}, {}, + error_spec_); +} + +XLA_TEST_F(SelectTest, SelectR1S0F32WithCmpR1S0S32s) { + // Similar to SelectR1S0F32WithConstantR1S0PRED, except that the pred vector + // is not a constant, but rather the result of comparing two other vectors. + ComputationBuilder builder(client_, TestName()); + auto v1 = builder.ConstantR1({}); + auto v2 = builder.ConstantR1({}); + auto cmp = builder.Eq(v1, v2); + auto on_true = builder.ConstantR1({}); + auto on_false = builder.ConstantR1({}); + auto select = builder.Select(cmp, on_true, on_false); + + ComputeAndCompareR1(&builder, {}, {}, error_spec_); +} + +TEST_F(SelectTest, SelectR1F32WithCmpR1S32s) { + // Similar to SelectR1F32WithConstantR1PRED, except that the pred vector is + // not a constant, but rather the result of comparing two other vectors. 
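+  // Eq on {1, 2, 3, 4, 5} and {9, 2, 9, 4, 9} yields the same predicate,
+  // {false, true, false, true, false}, so the expected output matches the
+  // constant-PRED test above.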
+ ComputationBuilder builder(client_, TestName()); + auto v1 = builder.ConstantR1({1, 2, 3, 4, 5}); + auto v2 = builder.ConstantR1({9, 2, 9, 4, 9}); + auto cmp = builder.Eq(v1, v2); + auto on_true = builder.ConstantR1({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); + auto on_false = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, -6.0f}); + auto select = builder.Select(cmp, on_true, on_false); + + ComputeAndCompareR1(&builder, {10.0f, 25.5f, 1.0f, -10.0f, -6.0f}, {}, + error_spec_); +} + +TEST_F(SelectTest, SelectR1F32WithCmpR1F32s) { + // Similar to SelectR1F32WithCmpR1S32s, except "gt"-comparing two R1F32s. + ComputationBuilder builder(client_, TestName()); + auto v1 = builder.ConstantR1({1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); + auto v2 = builder.ConstantR1({-1.0f, -2.0f, 13.0f, 14.0f, 4.4f}); + auto cmp = builder.Gt(v1, v2); + auto on_true = builder.ConstantR1({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); + auto on_false = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, -6.0f}); + auto select = builder.Select(cmp, on_true, on_false); + + ComputeAndCompareR1(&builder, {-2.5f, 25.5f, 1.0f, 10.0f, 6.0f}, {}, + error_spec_); +} + +TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsSmall) { + // Selects among two R1F32s, which come from parameters. v1 and v2 are + // compared, and selection between them happens based on a gt-comparison mask. + ComputationBuilder builder(client_, TestName()); + + ComputationDataHandle v1, v2; + std::unique_ptr param0_data = CreateR1Parameter( + {41.0f, 2.0f, 3.0f, 84.0f}, /*parameter_number=*/0, /*name=*/"v1", + /*builder=*/&builder, /*data_handle=*/&v1); + std::unique_ptr param1_data = CreateR1Parameter( + {21.0f, 22.0f, 23.0f, 24.0f}, /*parameter_number=*/1, /*name=*/"v2", + /*builder=*/&builder, /*data_handle=*/&v2); + + auto cmp = builder.Gt(v1, v2); + auto select = builder.Select(cmp, v1, v2); + ComputeAndCompareR1(&builder, {41.0f, 22.0f, 23.0f, 84.0f}, + {param0_data.get(), param1_data.get()}, + error_spec_); +} + +TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsLarge) { + // Similar to SelectR1F32WithCmpR1F32sFromParamsSmall, except that the + // data size passed in and out is large. + ComputationBuilder builder(client_, TestName()); + + // Number of floats in the data passed into and out of the computation. + constexpr int datalen = 15 * 1000; + + // The inputs are initialized with a special pattern where in the first third + // of the data v1[i] > v2[i] and elsewhere it's vice versa. + std::vector v1vec; + std::vector v2vec; + std::vector expected_vec; + for (int i = 0; i < datalen; ++i) { + float smaller = i; + float larger = i * 2; + if (i < datalen / 3) { + v1vec.push_back(larger); + v2vec.push_back(smaller); + } else { + v1vec.push_back(smaller); + v2vec.push_back(larger); + } + expected_vec.push_back(larger); + } + + ComputationDataHandle v1, v2; + std::unique_ptr param0_data = + CreateR1Parameter(v1vec, /*parameter_number=*/0, /*name=*/"v1", + /*builder=*/&builder, /*data_handle=*/&v1); + std::unique_ptr param1_data = + CreateR1Parameter(v2vec, /*parameter_number=*/1, /*name=*/"v2", + /*builder=*/&builder, /*data_handle=*/&v2); + + auto cmp = builder.Gt(v1, v2); + auto select = builder.Select(cmp, v1, v2); + ComputeAndCompareR1(&builder, expected_vec, + {param0_data.get(), param1_data.get()}, + error_spec_); +} + +TEST_F(SelectTest, SelectR1F32WithCmpR1S32ToScalar) { + // "gt"-compares a R1S32 with a S32 scalar, and uses the resulting R1PRED to + // select between two R1F32s. 
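+  // Gt compares every element of the vector against the scalar 0, so the
+  // resulting mask below is {true, false, true, false}.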
+ ComputationBuilder builder(client_, TestName()); + auto v = builder.ConstantR1({1, -1, 2, -2}); + auto s = builder.ConstantR0(0); + auto cmp = builder.Gt(v, s); + + auto on_true = builder.ConstantR1({11.0f, 22.0f, 33.0f, 44.0f}); + auto on_false = + builder.ConstantR1({-111.0f, -222.0f, -333.0f, -444.0f}); + auto select = builder.Select(cmp, on_true, on_false); + + ComputeAndCompareR1(&builder, {11.0f, -222.0f, 33.0f, -444.0f}, {}, + error_spec_); +} + +TEST_F(SelectTest, SelectR1F32WithCmpR1F32ToScalar) { + // "gt"-compares a R1F32 with a F32 scalar, and uses the resulting R1PRED to + // select between two R1F32s. + ComputationBuilder builder(client_, TestName()); + auto v = builder.ConstantR1({1.0f, 2.0f, 3.0f, 4.0f}); + auto s = builder.ConstantR0(2.5f); + auto cmp = builder.Gt(v, s); + + auto on_true = builder.ConstantR1({11.0f, 22.0f, 33.0f, 44.0f}); + auto on_false = + builder.ConstantR1({-111.0f, -222.0f, -333.0f, -444.0f}); + auto select = builder.Select(cmp, on_true, on_false); + + ComputeAndCompareR1(&builder, {-111.0f, -222.0f, 33.0f, 44.0f}, {}, + error_spec_); +} + +XLA_TEST_F(SelectTest, SelectR1S0F32WithScalarPredicate) { + for (bool which : {false, true}) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(which); + auto on_true = builder.ConstantR1({}); + auto on_false = builder.ConstantR1({}); + auto select = builder.Select(pred, on_true, on_false); + + ComputeAndCompareR1(&builder, {}, {}, error_spec_); + } +} + +TEST_F(SelectTest, SelectR1F32WithScalarPredicateTrue) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(true); + auto on_true = builder.ConstantR1({-2.5f, 25.5f}); + auto on_false = builder.ConstantR1({10.0f, 5.0f}); + auto select = builder.Select(pred, on_true, on_false); + + ComputeAndCompareR1(&builder, {-2.5f, 25.5f}, {}, error_spec_); +} + +TEST_F(SelectTest, SelectR1F32WithScalarPredicateFalse) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(false); + auto on_true = builder.ConstantR1({-2.5f, 25.5f}); + auto on_false = builder.ConstantR1({10.0f, 5.0f}); + auto select = builder.Select(pred, on_true, on_false); + + ComputeAndCompareR1(&builder, {10.0f, 5.0f}, {}, error_spec_); +} +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/set_return_value_test.cc b/tensorflow/compiler/xla/tests/set_return_value_test.cc new file mode 100644 index 0000000000..e15d744d95 --- /dev/null +++ b/tensorflow/compiler/xla/tests/set_return_value_test.cc @@ -0,0 +1,116 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +class SetReturnValueTest : public ClientLibraryTestBase {}; + +TEST_F(SetReturnValueTest, NoSetValue) { + ComputationBuilder builder(client_, "no_set_value"); + auto alpha = builder.ConstantR0(1.0); + auto x = builder.ConstantR1( + {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); + auto ax = builder.Add(alpha, x); + auto aax = builder.Add(alpha, ax); + + std::vector expected = {1.0, 3.0, 4.0, 0.0, -1.0, + 5.0, 6.0, -2.0, -3.0, 7.0}; + + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); +} + +TEST_F(SetReturnValueTest, SetValue) { + ComputationBuilder builder(client_, "set_value"); + auto alpha = builder.ConstantR0(1.0); + auto x = builder.ConstantR1( + {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); + auto ax = builder.Add(alpha, x); + auto aax = builder.Add(alpha, ax); + auto builder_status = builder.SetReturnValue(ax); + EXPECT_TRUE(builder_status.ok()); + + std::vector expected = {0.0, 2.0, 3.0, -1.0, -2.0, + 4.0, 5.0, -3.0, -4.0, 6.0}; + + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); +} + +TEST_F(SetReturnValueTest, SetValueAndModify) { + ComputationBuilder builder(client_, "set_value_and_modify"); + auto alpha = builder.ConstantR0(1.0); + auto x = builder.ConstantR1( + {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); + auto ax = builder.Add(alpha, x); + auto aax = builder.Add(alpha, ax); + auto builder_status = builder.SetReturnValue(ax); + EXPECT_TRUE(builder_status.ok()); + auto aaax = builder.Add(alpha, aax); + + std::vector expected = {0.0, 2.0, 3.0, -1.0, -2.0, + 4.0, 5.0, -3.0, -4.0, 6.0}; + + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); +} + +TEST_F(SetReturnValueTest, SetValueMultipleTimesAndModify) { + ComputationBuilder builder(client_, "set_value_multiple_times_and_modify"); + auto alpha = builder.ConstantR0(1.0); + auto x = builder.ConstantR1( + {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); + auto ax = builder.Add(alpha, x); + auto aax = builder.Add(alpha, ax); + auto builder_status = builder.SetReturnValue(aax); + EXPECT_TRUE(builder_status.ok()); + auto aaax = builder.Add(alpha, aax); + builder_status = builder.SetReturnValue(ax); + EXPECT_TRUE(builder_status.ok()); + auto aaaax = builder.Add(alpha, aaax); + + std::vector expected = {0.0, 2.0, 3.0, -1.0, -2.0, + 4.0, 5.0, -3.0, -4.0, 6.0}; + + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = 
tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc new file mode 100644 index 0000000000..d63582fb98 --- /dev/null +++ b/tensorflow/compiler/xla/tests/slice_test.cc @@ -0,0 +1,277 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Tests that slice operations can be performed. + +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/reference_util.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class SliceTest : public ClientLibraryTestBase { + protected: + template + void RunSliceTenToTwo() { + std::vector constant; + for (int i = 0; i < 10; ++i) { + constant.push_back(static_cast(i)); + } + + ComputationBuilder builder(client_, TestName()); + auto original = builder.ConstantR1(constant); + builder.Slice(original, {2}, {4}); + + const std::vector expected = {static_cast(2), + static_cast(3)}; + ComputeAndCompareR1(&builder, expected, {}); + } +}; + +XLA_TEST_F(SliceTest, SliceZeroToZeroF32) { + ComputationBuilder builder(client_, TestName()); + auto original = builder.ConstantR1({}); + builder.Slice(original, {0}, {0}); + + ComputeAndCompareR1(&builder, {}, {}); +} + +XLA_TEST_F(SliceTest, SliceTenToZeroF32) { + ComputationBuilder builder(client_, TestName()); + std::vector constant(10, 0.3); + auto original = builder.ConstantR1(constant); + builder.Slice(original, {7}, {7}); + + ComputeAndCompareR1(&builder, {}, {}); +} + +TEST_F(SliceTest, SliceTenToTwoF32) { RunSliceTenToTwo(); } + +XLA_TEST_F(SliceTest, SliceTenToTwoF64) { RunSliceTenToTwo(); } + +TEST_F(SliceTest, SliceTenToTwoU32) { RunSliceTenToTwo(); } + +TEST_F(SliceTest, SliceTenToTwoS32) { RunSliceTenToTwo(); } + +XLA_TEST_F(SliceTest, SliceTenToTwoU64) { RunSliceTenToTwo(); } + +XLA_TEST_F(SliceTest, SliceTenToTwoS64) { RunSliceTenToTwo(); } + +TEST_F(SliceTest, SliceTenToTen) { + const std::vector values = {0.0, 1.0, 2.0, 3.0, 4.0, + 5.0, 6.0, 7.0, 8.0, 9.0}; + + ComputationBuilder builder(client_, TestName()); + auto original = builder.ConstantR1(values); + builder.Slice(original, {0}, {10}); + + 
ComputeAndCompareR1(&builder, values, {}, ErrorSpec(0.000001)); +} + +TEST_F(SliceTest, SliceLastFourOf1024) { + std::vector values(1024); + std::iota(values.begin(), values.end(), 0.0); + + ComputationBuilder builder(client_, TestName()); + auto original = builder.ConstantR1(values); + builder.Slice(original, {1024 - 4}, {1024}); + + const std::vector expected = {1020, 1021, 1022, 1023}; + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.000001)); +} + +// TODO(b/28491443): Fix wrong result on CPU and GPU. Failed on +// 2016-05-01. Also b/28508652 +TEST_F(SliceTest, DISABLED_SliceUnaligned1024In4096Values) { + std::vector values(4096); + std::iota(values.begin(), values.end(), 0.0); + + ComputationBuilder builder(client_, TestName()); + auto original = builder.ConstantR1(values); + builder.Slice(original, {7}, {7 + 1024}); + + std::vector expected(1024); + std::iota(values.begin(), values.end(), 7.0); + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.000001)); +} + +XLA_TEST_F(SliceTest, Slice0x0to0x0F32) { + ComputationBuilder builder(client_, TestName()); + auto original = builder.ConstantR2FromArray2D(Array2D(0, 0)); + builder.Slice(original, {0, 0}, {0, 0}); + + ComputeAndCompareR2(&builder, Array2D(0, 0), {}); +} + +XLA_TEST_F(SliceTest, Slice0x20to0x5F32) { + ComputationBuilder builder(client_, TestName()); + auto original = builder.ConstantR2FromArray2D(Array2D(0, 20)); + builder.Slice(original, {0, 15}, {0, 20}); + + ComputeAndCompareR2(&builder, Array2D(0, 5), {}); +} + +XLA_TEST_F(SliceTest, Slice3x0to2x0F32) { + ComputationBuilder builder(client_, TestName()); + auto original = builder.ConstantR2FromArray2D(Array2D(3, 0)); + builder.Slice(original, {1, 0}, {3, 0}); + + ComputeAndCompareR2(&builder, Array2D(2, 0), {}); +} + +XLA_TEST_F(SliceTest, SliceQuadrantOf256x256) { + Array2D values(256, 256); + for (int row = 0; row < 256; ++row) { + for (int col = 0; col < 256; ++col) { + values(row, col) = (row << 10) | col; + } + } + + ComputationBuilder builder(client_, TestName()); + auto original = builder.ConstantR2FromArray2D(values); + builder.Slice(original, {128, 128}, {256, 256}); + + Array2D expected(128, 128); + for (int row = 0; row < 128; ++row) { + for (int col = 0; col < 128; ++col) { + expected(row, col) = ((row + 128) << 10) | (col + 128); + } + } + ComputeAndCompareR2(&builder, expected, {}, ErrorSpec(0.000001)); +} + +// Tests: (f32[1,4096], starts={0, 3072}, limits={1, 4096}) -> f32[1,1024]) +TEST_F(SliceTest, Slice_1x4096_To_1x1024) { + Array2D values(1, 4096); + std::iota(values.data(), values.data() + 4096, 0.0); + + ComputationBuilder builder(client_, TestName()); + auto original = builder.ConstantR2FromArray2D(values); + builder.Slice(original, {0, 3072}, {1, 4096}); + + Array2D expected(1, 1024); + std::iota(expected.data(), expected.data() + 1024, 3072.0); + ComputeAndCompareR2(&builder, expected, {}, ErrorSpec(0.000001)); +} + +// Tests slice: (f32[16,4], starts={0, 0}, limits={16, 2}) -> f32[16,2] +TEST_F(SliceTest, Slice_16x4_To_16x2) { + Array2D values(16, 4); + Array2D expected(16, 2); + for (int row = 0; row < 16; ++row) { + for (int col = 0; col < 4; ++col) { + values(row, col) = (row << 10) | col; + if (col < 2) { + expected(row, col) = (row << 10) | col; + } + } + } + ComputationBuilder builder(client_, TestName()); + auto original = builder.ConstantR2FromArray2D(values); + builder.Slice(original, {0, 0}, {16, 2}); + ComputeAndCompareR2(&builder, expected, {}, ErrorSpec(0.000001)); +} + +// Tests: (f32[2, 2, 24, 256], starts = {1, 0, 
8, 0}, ends = {2, 2, 16, 128} +TEST_F(SliceTest, SliceR4ThreeDimsMiddleMinor) { + Array4D values(2, 2, 24, 256); + values.FillRandom(3.14f); + auto expected = + ReferenceUtil::Slice4D(values, {{1, 0, 8, 0}}, {{2, 2, 16, 128}}); + ComputationBuilder builder(client_, TestName()); + auto original = builder.ConstantR4FromArray4D(values); + builder.Slice(original, {1, 0, 8, 0}, {2, 2, 16, 128}); + ComputeAndCompareR4(&builder, *expected, {}, ErrorSpec(0.000001)); +} + +struct R2Spec { + int64 input_dim0; + int64 input_dim1; + std::array slice_starts; + std::array slice_limits; + Layout layout; +}; + +// Parameterized test that generates patterned R2 values, slices them according +// to the R2Spec, and compares the results with the ReferenceUtil version. +class SliceR2Test : public ClientLibraryTestBase, + public ::testing::WithParamInterface {}; + +TEST_P(SliceR2Test, DoIt) { + const R2Spec& spec = GetParam(); + Array2D input(spec.input_dim0, spec.input_dim1); + input.FillUnique(); + + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2FromArray2D(input); + builder.Slice(a, spec.slice_starts, spec.slice_limits); + + std::unique_ptr> expected = + ReferenceUtil::Slice2D(input, spec.slice_starts, spec.slice_limits); + ComputeAndCompareR2(&builder, *expected, {}); +} + +// clang-format off +INSTANTIATE_TEST_CASE_P( + SliceR2TestInstantiation, SliceR2Test, + ::testing::Values( + R2Spec {4, 12, {{0, 3}}, {{4, 6}}, LayoutUtil::MakeLayout({0, 1})}, + R2Spec {4, 12, {{0, 3}}, {{4, 6}}, LayoutUtil::MakeLayout({1, 0})}, + R2Spec {16, 4, {{0, 2}}, {{16, 4}}, LayoutUtil::MakeLayout({0, 1})}, + R2Spec {16, 4, {{0, 2}}, {{16, 4}}, LayoutUtil::MakeLayout({1, 0})}, + R2Spec {256, 400, {{0, 300}}, {{256, 400}}, + LayoutUtil::MakeLayout({1, 0})}, + R2Spec {500, 400, {{111, 123}}, {{300, 257}}, + LayoutUtil::MakeLayout({1, 0})}, + R2Spec {500, 400, {{111, 123}}, {{300, 400}}, + LayoutUtil::MakeLayout({1, 0})}, + R2Spec {384, 512, {{128, 256}}, {{256, 384}}, + LayoutUtil::MakeLayout({1, 0})}, + R2Spec {357, 512, {{111, 256}}, {{301, 384}}, + LayoutUtil::MakeLayout({1, 0})} + ) +); +// clang-format on + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/test_macros.h b/tensorflow/compiler/xla/tests/test_macros.h new file mode 100644 index 0000000000..7f987a21ca --- /dev/null +++ b/tensorflow/compiler/xla/tests/test_macros.h @@ -0,0 +1,76 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// Macros for use in enabling/disabling tests on particular +// platforms. Marking a gunit test as disabled still ensures that it +// compiles. +// +// Implementation note: the macros are structured as follows: +// * Define the disabled macro to just pass the test name through (which, in +// effect, does not disable it at all) +// * If a XLA_TEST_BACKEND_$TARGET macro indicates we're compiling for +// $TARGET platform, make the disabled macro truly disable the test; i.e. by +// redefining the DISABLED_ON_$TARGET macro to prepend "DISABLED_" to the test +// name. + +#ifndef TENSORFLOW_COMPILER_XLA_TESTS_TEST_MACROS_H_ +#define TENSORFLOW_COMPILER_XLA_TESTS_TEST_MACROS_H_ + +#include + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/test.h" + +// Use this macro instead of directly using TEST_P for parameterized tests, +// otherwise DISABLED_ON_* macros nested in TEST_P will not get expanded since +// TEST_P stringifies its argument. That makes the test disabled for all targets +// when any one of the DISABLED_ON_* macro is used, and the test will just pass. +// TODO(b/29122096): Remove this once TEST_P fixes this problem. +#define XLA_TEST_P(test_case_name, test_name) TEST_P(test_case_name, test_name) + +#define DISABLED_ON_CPU(X) X +#define DISABLED_ON_CPU_PARALLEL(X) X +#define DISABLED_ON_GPU(X) X + +// We need this macro instead of pasting directly to support nesting +// the DISABLED_ON_FOO macros, as in the definition of DISABLED_ON_CPU. +// Otherwise the pasting is applied before macro expansion completes. +#define XLA_TEST_PASTE(A, B) A##B + +// We turn off clang-format so we can indent the macros for readability. +// clang-format off + +#ifdef XLA_TEST_BACKEND_CPU +# undef DISABLED_ON_CPU +# define DISABLED_ON_CPU(X) XLA_TEST_PASTE(DISABLED_, X) +#endif // XLA_TEST_BACKEND_CPU + +#ifdef XLA_TEST_BACKEND_CPU_PARALLEL +# undef DISABLED_ON_CPU +# define DISABLED_ON_CPU(X) XLA_TEST_PASTE(DISABLED_, X) +# undef DISABLED_ON_CPU_PARALLEL +# define DISABLED_ON_CPU_PARALLEL(X) XLA_TEST_PASTE(DISABLED_, X) +#endif // XLA_TEST_BACKEND_CPU_PARALLEL + +#ifdef XLA_TEST_BACKEND_GPU +# undef DISABLED_ON_GPU +# define DISABLED_ON_GPU(X) XLA_TEST_PASTE(DISABLED_, X) +#endif // XLA_TEST_BACKEND_GPU + +// clang-format on + +#define XLA_TEST_F(test_fixture, test_name) TEST_F(test_fixture, test_name) + +#endif // TENSORFLOW_COMPILER_XLA_TESTS_TEST_MACROS_H_ diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h new file mode 100644 index 0000000000..6a23df4d3c --- /dev/null +++ b/tensorflow/compiler/xla/tests/test_utils.h @@ -0,0 +1,115 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_TESTS_TEST_UTILS_H_ +#define TENSORFLOW_COMPILER_XLA_TESTS_TEST_UTILS_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace test_utils { + +// A class which generates pseudorandom numbers of a given type within a given +// range. Not cryptographically secure and likely not perfectly evenly +// distributed across the range but sufficient for most tests. +template +class PseudorandomGenerator { + public: + explicit PseudorandomGenerator(NativeT min_value, NativeT max_value, + uint32 seed) + : min_(min_value), max_(max_value), generator_(seed) {} + + // Get a pseudorandom value. + NativeT get() { + std::uniform_real_distribution<> distribution; + return static_cast(min_ + + (max_ - min_) * distribution(generator_)); + } + + private: + NativeT min_; + NativeT max_; + std::mt19937 generator_; +}; + +// Convenience function for creating a rank-2 array with arbitrary layout. +template +std::unique_ptr CreateR2LiteralWithLayout( + std::initializer_list> values, + tensorflow::gtl::ArraySlice minor_to_major) { + auto literal = MakeUnique(); + const int64 d0 = values.size(); + const int64 d1 = values.begin()->size(); + LiteralUtil::PopulateWithValue(0, {d0, d1}, literal.get()); + *literal->mutable_shape()->mutable_layout() = + LayoutUtil::MakeLayout(minor_to_major); + TF_CHECK_OK(ShapeUtil::ValidateShape(literal->shape())); + + int64 dim0 = 0; + for (auto inner_list : values) { + int64 dim1 = 0; + for (auto value : inner_list) { + LiteralUtil::Set(literal.get(), {dim0, dim1}, value); + ++dim1; + } + ++dim0; + } + return literal; +} + +// Convenience function for creating a rank-3 array with arbitrary layout. +template +std::unique_ptr CreateR3LiteralWithLayout( + std::initializer_list>> + values, + tensorflow::gtl::ArraySlice minor_to_major) { + auto literal = MakeUnique(); + const int64 d0 = values.size(); + const int64 d1 = values.begin()->size(); + const int64 d2 = values.begin()->begin()->size(); + LiteralUtil::PopulateWithValue(0, {d0, d1, d2}, literal.get()); + *literal->mutable_shape()->mutable_layout() = + LayoutUtil::MakeLayout(minor_to_major); + TF_CHECK_OK(ShapeUtil::ValidateShape(literal->shape())); + + int64 dim0 = 0; + for (auto inner_list : values) { + int64 dim1 = 0; + for (auto inner_inner_list : inner_list) { + int64 dim2 = 0; + for (auto value : inner_inner_list) { + LiteralUtil::Set(literal.get(), {dim0, dim1, dim2}, value); + ++dim2; + } + ++dim1; + } + ++dim0; + } + return literal; +} + +} // namespace test_utils +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_TESTS_TEST_UTILS_H_ diff --git a/tensorflow/compiler/xla/tests/transpose_test.cc b/tensorflow/compiler/xla/tests/transpose_test.cc new file mode 100644 index 0000000000..79f251bbc4 --- /dev/null +++ b/tensorflow/compiler/xla/tests/transpose_test.cc @@ -0,0 +1,203 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/reference_util.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +class TransposeTest : public ClientLibraryTestBase { + public: + ErrorSpec error_spec_{0.0001}; + + protected: + void TestTransposeConstant021(size_t n1, size_t n2, size_t n3); +}; + +XLA_TEST_F(TransposeTest, Transpose0x0) { + ComputationBuilder builder(client_, "Transpose"); + auto lhs = builder.ConstantR2FromArray2D(Array2D(0, 0)); + auto result = builder.Transpose(lhs, {1, 0}); + + ComputeAndCompareR2(&builder, Array2D(0, 0), {}, error_spec_); +} + +XLA_TEST_F(TransposeTest, Transpose0x42) { + ComputationBuilder builder(client_, "Transpose"); + auto lhs = builder.ConstantR2FromArray2D(Array2D(0, 42)); + auto result = builder.Transpose(lhs, {1, 0}); + + ComputeAndCompareR2(&builder, Array2D(42, 0), {}, error_spec_); +} + +XLA_TEST_F(TransposeTest, Transpose7x0) { + ComputationBuilder builder(client_, "Transpose"); + auto lhs = builder.ConstantR2FromArray2D(Array2D(7, 0)); + auto result = builder.Transpose(lhs, {1, 0}); + + ComputeAndCompareR2(&builder, Array2D(0, 7), {}, error_spec_); +} + +TEST_F(TransposeTest, Transpose2x2) { + ComputationBuilder builder(client_, "Transpose"); + auto lhs = builder.ConstantR2({ + {1.0, 2.0}, {3.0, 4.0}, + }); + auto result = builder.Transpose(lhs, {1, 0}); + + Array2D expected({{1.0f, 3.0f}, {2.0f, 4.0f}}); + + ComputeAndCompareR2(&builder, expected, {}, error_spec_); +} + +XLA_TEST_F(TransposeTest, Transpose0x2x3_2x3x0) { + ComputationBuilder builder(client_, "Transpose"); + auto operand = builder.ConstantR3FromArray3D(Array3D(0, 2, 3)); + auto result = builder.Transpose(operand, {1, 2, 0}); + + ComputeAndCompareR3(&builder, Array3D(2, 3, 0), {}); +} + +TEST_F(TransposeTest, Transpose1x2x3_2x3x1) { + ComputationBuilder builder(client_, "Transpose"); + auto operand = builder.ConstantR3FromArray3D({{{1, 2, 3}, {4, 5, 6}}}); + auto result = builder.Transpose(operand, {1, 2, 0}); + + Array3D expected({{{1}, {2}, {3}}, {{4}, {5}, {6}}}); + + ComputeAndCompareR3(&builder, expected, {}); +} + +TEST_F(TransposeTest, Transpose1x2x3_3x2x1) { + ComputationBuilder builder(client_, "Transpose"); + auto operand = builder.ConstantR3FromArray3D({{{1, 2, 3}, {4, 5, 6}}}); + auto result = builder.Transpose(operand, {2, 1, 0}); + + Array3D expected({{{1}, {4}}, {{2}, {5}}, {{3}, {6}}}); + + ComputeAndCompareR3(&builder, expected, {}); +} + +TEST_F(TransposeTest, Transpose1x2x3_1x2x3) { + ComputationBuilder builder(client_, 
"Transpose"); + auto operand = builder.ConstantR3FromArray3D({{{1, 2, 3}, {4, 5, 6}}}); + auto result = builder.Transpose(operand, {0, 1, 2}); + + Array3D expected({{{1, 2, 3}, {4, 5, 6}}}); + + ComputeAndCompareR3(&builder, expected, {}); +} + +TEST_F(TransposeTest, MultiTranspose3x2) { + Array2D input({{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}}); + Array2D transposed({{1.0f, 3.0f, 5.0f}, {2.0f, 4.0f, 6.0f}}); + + for (int transposes = 0; transposes <= 10; ++transposes) { + ComputationBuilder builder(client_, "Transpose"); + auto computed = builder.ConstantR2FromArray2D(input); + for (int i = 0; i < transposes; ++i) { + computed = builder.Transpose(computed, {1, 0}); + } + const Array2D& expected = transposes % 2 == 0 ? input : transposed; + ComputeAndCompareR2(&builder, expected, {}, error_spec_); + } +} + +// Test for transposing [1x1] matrix. +TEST_F(TransposeTest, Small_1x1) { + auto aoperand = MakeLinspaceArray2D(0.0, 1.0, 1, 1); + + ComputationBuilder builder(client_, "transpose_1x1"); + auto operand = builder.ConstantR2FromArray2D(*aoperand); + builder.Transpose(operand, {1, 0}); + + auto expected = ReferenceUtil::TransposeArray2D(*aoperand); + ComputeAndCompareR2(&builder, *expected, {}, ErrorSpec(1e-4)); +} + +// Test for transposing [2x2] matrix. +TEST_F(TransposeTest, Small_2x2) { + auto aoperand = MakeLinspaceArray2D(0.0, 4.0, 2, 2); + + ComputationBuilder builder(client_, "transpose_2x2"); + auto operand = builder.ConstantR2FromArray2D(*aoperand); + builder.Transpose(operand, {1, 0}); + + auto expected = ReferenceUtil::TransposeArray2D(*aoperand); + ComputeAndCompareR2(&builder, *expected, {}, ErrorSpec(1e-4)); +} + +void TransposeTest::TestTransposeConstant021(size_t n1, size_t n2, size_t n3) { + Array3D aoperand(n1, n2, n3); + Array3D expected(n1, n3, n2); + for (size_t i = 0; i < n1; ++i) { + for (size_t j = 0; j < n2; ++j) { + for (size_t k = 0; k < n3; ++k) { + aoperand(i, j, k) = i * n3 * n2 + j * n3 + k; + expected(i, k, j) = aoperand(i, j, k); + } + } + } + + ComputationBuilder builder(client_, TestName()); + auto operand = builder.ConstantR3FromArray3D(aoperand); + builder.Transpose(operand, {0, 2, 1}); + + ComputeAndCompareR3(&builder, expected, {}); +} + +TEST_F(TransposeTest, TransposeConstant021_SingleIncompleteTilePerLayer) { + TestTransposeConstant021(2, 2, 3); +} + +TEST_F(TransposeTest, TransposeConstant021_SingleCompleteTilePerLayer) { + TestTransposeConstant021(2, 32, 32); +} + +TEST_F(TransposeTest, TransposeConstant021_MultipleTilesPerLayer) { + TestTransposeConstant021(2, 70, 35); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc new file mode 100644 index 0000000000..cea9316a6d --- /dev/null +++ b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -0,0 +1,415 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +class TupleTest : public ClientLibraryTestBase { + public: + ErrorSpec error_spec_{0.0001}; +}; + +// Tests the creation of tuple data. +XLA_TEST_F(TupleTest, TupleCreate) { + ComputationBuilder builder(client_, TestName()); + + const float constant_scalar = 7.3f; + std::initializer_list constant_vector = {1.1f, 2.0f, 3.3f}; + std::initializer_list> constant_matrix = { + {1.1f, 2.2f, 3.5f}, // row 0 + {4.8f, 5.0f, 6.7f}, // row 1 + }; + auto result = builder.Tuple({builder.ConstantR0(constant_scalar), + builder.ConstantR1(constant_vector), + builder.ConstantR2(constant_matrix)}); + + auto expected = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR0(constant_scalar).get(), + LiteralUtil::CreateR1(constant_vector).get(), + LiteralUtil::CreateR2(constant_matrix).get()}); + ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); +} + +// Tests the creation of tuple data. +XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) { + ComputationBuilder builder(client_, TestName()); + + auto result = builder.Tuple( + {builder.ConstantR0(7.0), builder.ConstantR1({})}); + + auto expected = + LiteralUtil::MakeTuple({LiteralUtil::CreateR0(7.0).get(), + LiteralUtil::CreateR1({}).get()}); + ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); +} + +// Tests the creation of an empty tuple. +XLA_TEST_F(TupleTest, EmptyTupleCreate) { + ComputationBuilder builder(client_, TestName()); + auto result = builder.Tuple({}); + auto expected = LiteralUtil::MakeTuple({}); + ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); +} + +// Trivial test for extracting a tuple element with GetTupleElement. +XLA_TEST_F(TupleTest, GetTupleElement) { + ComputationBuilder builder(client_, TestName()); + std::initializer_list constant_vector = {1.f, 2.f, 3.f}; + std::initializer_list> constant_matrix = { + {1.f, 2.f, 3.f}, // row 0 + {4.f, 5.f, 6.f}, // row 1 + }; + auto tuple_data = builder.Tuple({builder.ConstantR1(constant_vector), + builder.ConstantR2(constant_matrix)}); + auto matrix_element = builder.GetTupleElement(tuple_data, 1); + ComputeAndCompareR2(&builder, Array2D(constant_matrix), {}, + error_spec_); +} + +// Trivial test for extracting a tuple element with GetTupleElement. 
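+// As a minimal sketch of the indexing convention assumed by these tests
+// (indices follow the order of the operands passed to Tuple()):
+//
+//   auto t = builder.Tuple({vec, mat});      // tuple of (shape(vec), shape(mat))
+//   auto v = builder.GetTupleElement(t, 0);  // recovers vec
+//   auto m = builder.GetTupleElement(t, 1);  // recovers mat
+//
+// The test below exercises the same pattern when one tuple element has zero
+// elements.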
+XLA_TEST_F(TupleTest, GetTupleElementWithZeroElements) { + ComputationBuilder builder(client_, TestName()); + auto tuple_data = builder.Tuple( + {builder.ConstantR1({}), + builder.ConstantR2FromArray2D(Array2D(0, 101))}); + auto matrix_element = builder.GetTupleElement(tuple_data, 1); + ComputeAndCompareR2(&builder, Array2D(0, 101), {}, error_spec_); +} + +// Extracts both elements from a tuple with GetTupleElement and then adds them +// together. +XLA_TEST_F(TupleTest, AddTupleElements) { + ComputationBuilder builder(client_, TestName()); + std::initializer_list constant_vector = {1.f, 2.f, 3.f}; + std::initializer_list> constant_matrix = { + {1.f, 2.f, 3.f}, // row 0 + {4.f, 5.f, 6.f}, // row 1 + }; + auto tuple_data = builder.Tuple({builder.ConstantR1(constant_vector), + builder.ConstantR2(constant_matrix)}); + auto vector_element = builder.GetTupleElement(tuple_data, 0); + auto matrix_element = builder.GetTupleElement(tuple_data, 1); + auto vector_shape = builder.GetShape(vector_element).ConsumeValueOrDie(); + auto matrix_shape = builder.GetShape(matrix_element).ConsumeValueOrDie(); + auto result = builder.Add(matrix_element, vector_element, + /*broadcast_dimensions=*/{1}); + + Array2D expected({ + {2.f, 4.f, 6.f}, // row 0 + {5.f, 7.f, 9.f}, // row 1 + }); + ASSERT_TRUE(ShapeUtil::ShapeIs(*vector_shape, F32, {3})); + ASSERT_TRUE(ShapeUtil::ShapeIs(*matrix_shape, F32, {/*y=*/2, /*x=*/3})); + ComputeAndCompareR2(&builder, expected, {}, error_spec_); +} + +// Extracts both elements from a tuple and then puts them into a new tuple in +// the opposite order. +XLA_TEST_F(TupleTest, TupleGTEToTuple) { + ComputationBuilder builder(client_, TestName()); + std::initializer_list constant_vector = {1.f, 2.f, 3.f}; + std::initializer_list> constant_matrix = { + {1.f, 2.f, 3.f}, // row 0 + {4.f, 5.f, 6.f}, // row 1 + }; + auto tuple_data = builder.Tuple({builder.ConstantR1(constant_vector), + builder.ConstantR2(constant_matrix)}); + auto new_tuple = builder.Tuple({builder.GetTupleElement(tuple_data, 1), + builder.GetTupleElement(tuple_data, 0)}); + auto expected = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR2(constant_matrix).get(), + LiteralUtil::CreateR1(constant_vector).get()}); + ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); +} + +// Builds two new tuples from an existing tuple (by means of GetTupleElement), +// then adds up the components of the new tuples. 
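+// The final Add in the test below broadcasts its rank-1 operand along
+// dimension 1 of the rank-2 operand, the same convention already used in
+// AddTupleElements above; as a small sketch:
+//
+//   builder.Add(matrix /*f32[2,3]*/, vector /*f32[3]*/,
+//               /*broadcast_dimensions=*/{1});
+//
+// maps vector element j onto column j of the matrix before the element-wise
+// add.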
+XLA_TEST_F(TupleTest, TupleGTEToTupleToGTEAdd) { + // + // v------ --(GTE 0)-- --(GTE 0)---------- + // \ / \ / \ + // (tuple)-- (tuple01)-- \ + // / | \ / \ \ + // m------ | --(GTE 1)-- --(GTE 1)------------ \ + // | \ \ + // | (add) + // | / / + // |--------(GTE 1)-- --(GTE 0)------------ / + // \ \ / / + // \ (tuple10)-- / + // \ / \ / + // -----(GTE 0)-- --(GTE 1)---------- + ComputationBuilder builder(client_, TestName()); + std::initializer_list constant_vector = {1.f, 2.f, 3.f}; + std::initializer_list> constant_matrix = { + {1.f, 2.f, 3.f}, // row 0 + {4.f, 5.f, 6.f}, // row 1 + }; + auto tuple_data = builder.Tuple({builder.ConstantR1(constant_vector), + builder.ConstantR2(constant_matrix)}); + auto new_tuple01 = builder.Tuple({builder.GetTupleElement(tuple_data, 0), + builder.GetTupleElement(tuple_data, 1)}); + auto new_tuple10 = builder.Tuple({builder.GetTupleElement(tuple_data, 1), + builder.GetTupleElement(tuple_data, 0)}); + auto vector_from_01 = builder.GetTupleElement(new_tuple01, 0); + auto vector_from_10 = builder.GetTupleElement(new_tuple10, 1); + auto matrix_from_01 = builder.GetTupleElement(new_tuple01, 1); + auto matrix_from_10 = builder.GetTupleElement(new_tuple10, 0); + + auto addvectors = builder.Add(vector_from_01, vector_from_10); + auto addmatrices = builder.Add(matrix_from_01, matrix_from_10); + + auto result = builder.Add(addmatrices, addvectors, + /*broadcast_dimensions=*/{1}); + + Array2D expected({ + {4.f, 8.f, 12.f}, // row 0 + {10.f, 14.f, 18.f}, // row 1 + }); + ComputeAndCompareR2(&builder, expected, {}, error_spec_); +} + +XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnFalse)) { + // Tests a selection between tuples with "false" path taken. + ComputationBuilder builder(client_, TestName()); + + std::initializer_list vec1 = {1.f, 2.f, 3.f}; + std::initializer_list vec2 = {2.f, 4.f, 6.f}; + auto tuple12 = builder.Tuple( + {builder.ConstantR1(vec1), builder.ConstantR1(vec2)}); + auto tuple21 = builder.Tuple( + {builder.ConstantR1(vec2), builder.ConstantR1(vec1)}); + + auto select = + builder.Select(builder.ConstantR0(false), tuple12, tuple21); + auto expected = + LiteralUtil::MakeTuple({LiteralUtil::CreateR1(vec2).get(), + LiteralUtil::CreateR1(vec1).get()}); + ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); +} + +XLA_TEST_F(TupleTest, TuplesInAMap) { + Computation tuple_computation; + { + // tuple_computation(x) = 100 * min(x, x^2) + max(x, x^2) using tuples. + // + // Need to put a select in there to prevent HLO-level optimizations from + // optimizing out the tuples. + ComputationBuilder b(client_, "sort_square"); + auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); + auto x2 = b.Mul(x, x); + auto x_smaller_tuple = b.Tuple({x, x2}); + auto x2_smaller_tuple = b.Tuple({x2, x}); + auto sorted = b.Select(b.Lt(x, x2), x_smaller_tuple, x2_smaller_tuple); + auto smaller = b.GetTupleElement(sorted, 0); + auto greater = b.GetTupleElement(sorted, 1); + b.Add(greater, b.Mul(b.ConstantR0(100.0f), smaller)); + auto computation_status = b.Build(); + ASSERT_IS_OK(computation_status.status()); + tuple_computation = computation_status.ConsumeValueOrDie(); + } + + ComputationBuilder b(client_, TestName()); + auto input = b.ConstantR1({-1.0f, 1.0f, 2.1f}); + b.Map({input}, tuple_computation); + ComputeAndCompareR1(&b, {-99.0f, 101.0f, 214.41f}, {}, error_spec_); +} + +XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnTrue)) { + // Tests a selection between tuples with "true" path taken. 
+ ComputationBuilder builder(client_, TestName()); + + std::initializer_list vec1 = {1.f, 2.f, 3.f}; + std::initializer_list vec2 = {2.f, 4.f, 6.f}; + auto tuple12 = builder.Tuple( + {builder.ConstantR1(vec1), builder.ConstantR1(vec2)}); + auto tuple21 = builder.Tuple( + {builder.ConstantR1(vec2), builder.ConstantR1(vec1)}); + + auto select = + builder.Select(builder.ConstantR0(true), tuple12, tuple21); + auto expected = + LiteralUtil::MakeTuple({LiteralUtil::CreateR1(vec1).get(), + LiteralUtil::CreateR1(vec2).get()}); + ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); +} + +XLA_TEST_F(TupleTest, SelectBetweenTuplesElementResult) { + // Tests a selection between tuples but the final result is an element of the + // tuple, not the whole tuple. + ComputationBuilder builder(client_, TestName()); + + std::initializer_list vec1 = {1.f, 2.f, 3.f}; + std::initializer_list vec2 = {2.f, 4.f, 6.f}; + auto tuple12 = builder.Tuple( + {builder.ConstantR1(vec1), builder.ConstantR1(vec2)}); + auto tuple21 = builder.Tuple( + {builder.ConstantR1(vec2), builder.ConstantR1(vec1)}); + + auto select = + builder.Select(builder.ConstantR0(false), tuple12, tuple21); + auto element = builder.GetTupleElement(select, 0); + + ComputeAndCompareR1(&builder, vec2, {}, error_spec_); +} + +// Cascaded selects between tuple types. +XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesCascaded)) { + // + // vec1 vec2 vec2 vec1 + // | | | | + // | | | | + // (tuple 12) (tuple 21) + // \ / + // \ / + // \ / + // true -- --(GTE 0)--(select 1) + // \ / | + // (pred tuple)-- | --(GTE 0)-- + // / \ V / \ + // false -- --(GTE 1)--(select 2)-- --(add) + // / \ / + // / --(GTE 1)-- + // / + // (tuple 21) + ComputationBuilder builder(client_, TestName()); + + std::initializer_list vec1 = {1.f, 2.f, 3.f}; + std::initializer_list vec2 = {2.f, 4.f, 6.f}; + + auto pred_tuple = builder.Tuple( + {builder.ConstantR0(true), builder.ConstantR0(false)}); + auto tuple12 = builder.Tuple( + {builder.ConstantR1(vec1), builder.ConstantR1(vec2)}); + auto tuple21 = builder.Tuple( + {builder.ConstantR1(vec2), builder.ConstantR1(vec1)}); + + auto select1 = + builder.Select(builder.GetTupleElement(pred_tuple, 0), tuple12, tuple21); + auto select2 = + builder.Select(builder.GetTupleElement(pred_tuple, 1), tuple21, select1); + auto result = builder.Add(builder.GetTupleElement(select2, 0), + builder.GetTupleElement(select2, 1)); + + ComputeAndCompareR1(&builder, {3.f, 6.f, 9.f}, {}, error_spec_); +} + +XLA_TEST_F(TupleTest, + DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesReuseConstants)) { + // Similar to SelectBetweenTuples, but the constants are shared between the + // input tuples. 
+ ComputationBuilder builder(client_, TestName()); + + std::initializer_list vec1 = {1.f, 2.f, 3.f}; + std::initializer_list vec2 = {2.f, 4.f, 6.f}; + auto c1 = builder.ConstantR1(vec1); + auto c2 = builder.ConstantR1(vec2); + auto tuple12 = builder.Tuple({c1, c2}); + auto tuple21 = builder.Tuple({c2, c1}); + + auto select = + builder.Select(builder.ConstantR0(false), tuple12, tuple21); + auto expected = + LiteralUtil::MakeTuple({LiteralUtil::CreateR1(vec2).get(), + LiteralUtil::CreateR1(vec1).get()}); + ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); +} + +XLA_TEST_F(TupleTest, NestedTuples) { + ComputationBuilder builder(client_, TestName()); + auto inner_tuple = builder.Tuple( + {builder.ConstantR1({1.0, 2.0}), builder.ConstantR0(42.0)}); + auto outer_tuple = + builder.Tuple({inner_tuple, builder.ConstantR1({22.0, 44.0})}); + + auto expected_v1 = LiteralUtil::CreateR1({1.0, 2.0}); + auto expected_s = LiteralUtil::CreateR0(42.0); + auto expected_inner_tuple = + LiteralUtil::MakeTuple({expected_v1.get(), expected_s.get()}); + auto expected_v2 = LiteralUtil::CreateR1({22.0, 44.0}); + auto expected = + LiteralUtil::MakeTuple({expected_inner_tuple.get(), expected_v2.get()}); + + ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); +} + +XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) { + ComputationBuilder builder(client_, TestName()); + + Shape data_shape = ShapeUtil::MakeShape(F32, {3}); + Shape inner_tuple_shape = ShapeUtil::MakeTupleShape({data_shape, data_shape}); + Shape outer_tuple_shape = + ShapeUtil::MakeTupleShape({inner_tuple_shape, data_shape}); + + auto input = builder.Parameter(0, outer_tuple_shape, "input"); + auto gte0 = builder.GetTupleElement(input, 0); + auto gte1 = builder.GetTupleElement(gte0, 1); + builder.Add(gte1, builder.ConstantR1({10.0, 11.0, 12.0})); + + std::unique_ptr data = + client_ + ->TransferToServer(*LiteralUtil::MakeTuple({ + LiteralUtil::MakeTuple( + { + LiteralUtil::CreateR1({1.0, 2.0, 3.0}).get(), + LiteralUtil::CreateR1({4.0, 5.0, 6.0}).get(), + }) + .get(), + LiteralUtil::CreateR1({7.0, 8.0, 9.0}).get(), + })) + .ConsumeValueOrDie(); + + std::vector arguments = {data.get()}; + const std::vector expected = {4.0 + 10.0, 5.0 + 11.0, 6.0 + 12.0}; + ComputeAndCompareR1(&builder, expected, arguments, ErrorSpec(1e-5)); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/unary_op_test.cc b/tensorflow/compiler/xla/tests/unary_op_test.cc new file mode 100644 index 0000000000..fdbaa0d178 --- /dev/null +++ b/tensorflow/compiler/xla/tests/unary_op_test.cc @@ -0,0 +1,179 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class UnaryOpTest : public ClientLibraryTestBase { + protected: + template + T inf() { + return std::numeric_limits::infinity(); + } + template + void AbsSize0TestHelper() { + ComputationBuilder builder(client_, TestName()); + auto arg = builder.ConstantR1({}); + auto abs = builder.Abs(arg); + + ComputeAndCompareR1(&builder, {}, {}); + } + + template + void AbsTestHelper() { + ComputationBuilder builder(client_, TestName()); + auto arg = builder.ConstantR1({-2, 25, 0, -123, inf(), -inf()}); + auto abs = builder.Abs(arg); + + ComputeAndCompareR1(&builder, {2, 25, 0, 123, inf(), inf()}, {}); + } + + template + void SignTestHelper() { + ComputationBuilder builder(client_, TestName()); + auto arg = builder.ConstantR1( + {-2, 25, 0, static_cast(-0.0), -123, inf(), -inf()}); + auto sign = builder.Sign(arg); + + ComputeAndCompareR1(&builder, {-1, 1, 0, 0, -1, 1, -1}, {}); + } + + template + void SignAbsTestHelper() { + ComputationBuilder builder(client_, TestName()); + auto arg = builder.ConstantR1({-2, 25, 0, -123}); + auto sign = builder.Sign(arg); + auto abs = builder.Abs(arg); + builder.Sub(builder.Mul(sign, abs), arg); + + ComputeAndCompareR1(&builder, {0, 0, 0, 0}, {}); + } +}; + +template <> +int UnaryOpTest::inf() { + return 2147483647; +} + +XLA_TEST_F(UnaryOpTest, AbsTestR1Size0) { + AbsSize0TestHelper(); + AbsSize0TestHelper(); +} + +TEST_F(UnaryOpTest, AbsTestR1) { + AbsTestHelper(); + AbsTestHelper(); +} + +TEST_F(UnaryOpTest, AbsTestR0) { + ComputationBuilder builder(client_, TestName()); + auto argi = builder.ConstantR0(-5); + auto absi = builder.Abs(argi); + auto argf = builder.ConstantR0(-3.0f); + auto absf = builder.Abs(argf); + auto argf0 = builder.ConstantR0(-0.0f); + auto absf0 = builder.Abs(argf0); + builder.Add(absf0, builder.Add(absf, builder.ConvertElementType( + absi, PrimitiveType::F32))); + + ComputeAndCompareR0(&builder, 8.0f, {}); +} + +TEST_F(UnaryOpTest, SignTestR0) { + ComputationBuilder builder(client_, TestName()); + auto argi = builder.ConstantR0(-5); + auto absi = builder.Sign(argi); + auto argf = builder.ConstantR0(-4.0f); + auto absf = builder.Sign(argf); + auto argf0 = builder.ConstantR0(-0.0f); + auto absf0 = builder.Sign(argf0); + builder.Add(absf0, builder.Add(absf, builder.ConvertElementType( + absi, PrimitiveType::F32))); + + ComputeAndCompareR0(&builder, -2.0f, {}); +} + +TEST_F(UnaryOpTest, SignTestR1) { + SignTestHelper(); + SignTestHelper(); +} + 
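+// A minimal sketch composing the helpers above, using only the Abs call and
+// the exact-comparison pattern already exercised in this file: Abs is
+// idempotent, so Abs(Abs(x)) should equal Abs(x).
+TEST_F(UnaryOpTest, AbsIdempotentSketch) {
+  ComputationBuilder builder(client_, TestName());
+  auto arg = builder.ConstantR1<float>({-2.0f, 25.0f, 0.0f, -123.0f});
+  builder.Abs(builder.Abs(arg));
+
+  ComputeAndCompareR1<float>(&builder, {2.0f, 25.0f, 0.0f, 123.0f}, {});
+}
+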
+TEST_F(UnaryOpTest, SignAbsTestR1) { + SignAbsTestHelper(); + SignAbsTestHelper(); +} + +TEST_F(UnaryOpTest, UnsignedAbsTestR1) { + ComputationBuilder builder(client_, TestName()); + auto arg = builder.ConstantR1( + {2, 25, 0, 123, std::numeric_limits::max()}); + auto abs = builder.Abs(arg); + + ComputeAndCompareR1( + &builder, {2, 25, 0, 123, std::numeric_limits::max()}, {}); +} + +TEST_F(UnaryOpTest, UnsignedSignTestR1) { + ComputationBuilder builder(client_, TestName()); + auto arg = builder.ConstantR1( + {2, 25, 0, 123, std::numeric_limits::max()}); + auto sign = builder.Sign(arg); + + ComputeAndCompareR1(&builder, {1, 1, 0, 1, 1}, {}); +} + +TEST_F(UnaryOpTest, SignAbsTestR2) { + ComputationBuilder builder(client_, TestName()); + auto arg = builder.ConstantR2({{1.0, -2.0}, {-3.0, 4.0}}); + auto sign = builder.Sign(arg); + auto abs = builder.Abs(arg); + builder.Sub(builder.Mul(sign, abs), arg); + + ComputeAndCompareR2(&builder, {{0, 0}, {0, 0}}, {}); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc b/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc new file mode 100644 index 0000000000..7f3d7d9cb4 --- /dev/null +++ b/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc @@ -0,0 +1,235 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/array3d.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +class VecOpsReduceTest : public ClientLibraryTestBase { + public: + VecOpsReduceTest() : builder_(client_, TestName()) {} + + ComputationDataHandle BuildSampleConstantCube() { + // clang-format off + Array3D x3d({ + {{1.0, 2.0, 3.0}, // | dim 1 // } plane 0 in dim 0 + {4.0, 5.0, 6.0}}, // V // } + // ---- dim 2 ----> + {{1.0, 2.0, 3.0}, // } plane 1 in dim 0 + {4.0, 5.0, 6.0}}, + {{1.0, 2.0, 3.0}, // } plane 2 in dim 0 + {4.0, 5.0, 6.0}}}); + // clang-format on + return builder_.ConstantR3FromArray3D(x3d); + } + + ComputationBuilder builder_; + ErrorSpec errspec_{1e-3, 0}; +}; + +TEST_F(VecOpsReduceTest, AddReduceR1F32) { + auto sum_reducer = CreateScalarAddComputation(F32, &builder_); + + auto x = builder_.ConstantR1( + {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); + auto add_reduce = + builder_.Reduce(x, builder_.ConstantR0(0.0f), sum_reducer, + /*dimensions_to_reduce=*/{0}); + + ComputeAndCompareR0(&builder_, -4.2f, {}, errspec_); +} + +TEST_F(VecOpsReduceTest, AddReduceBigR1F32) { + auto sum_reducer = CreateScalarAddComputation(F32, &builder_); + + std::vector input(3000); + std::iota(input.begin(), input.end(), 100.0f); + + auto x = builder_.ConstantR1(input); + auto add_reduce = + builder_.Reduce(x, builder_.ConstantR0(0.0f), sum_reducer, + /*dimensions_to_reduce=*/{0}); + + float expected = std::accumulate(input.begin(), input.end(), 0.0f); + ComputeAndCompareR0(&builder_, expected, {}, errspec_); +} + +TEST_F(VecOpsReduceTest, MaxReduceR1F32) { + auto max_reducer = CreateScalarMax(); + + auto x = builder_.ConstantR1( + {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); + auto max_reduce = + builder_.Reduce(x, builder_.ConstantR0(0.0f), max_reducer, + /*dimensions_to_reduce=*/{0}); + + ComputeAndCompareR0(&builder_, 2.6f, {}, errspec_); +} + +TEST_F(VecOpsReduceTest, MaxReduceR1F32WithNontrivialInit) { + auto max_reducer = CreateScalarMax(); + + auto x = builder_.ConstantR1( + {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); + auto max_reduce = + builder_.Reduce(x, builder_.ConstantR0(4.0f), max_reducer, + /*dimensions_to_reduce=*/{0}); + + ComputeAndCompareR0(&builder_, 4.0f, {}, errspec_); +} + +TEST_F(VecOpsReduceTest, AddReduceR2F32Dim1) { + auto sum_reducer = CreateScalarAddComputation(F32, &builder_); + + // clang-format off + auto x = builder_.ConstantR2({ + {1.0, 2.0, 3.0}, // | dim 0 + {4.0, 5.0, 6.0}}); // | + // ------ dim 1 ---------- + // clang-format on + + auto add_reduce = + builder_.Reduce(x, builder_.ConstantR0(0.0f), sum_reducer, + /*dimensions_to_reduce=*/{1}); + + ComputeAndCompareR1(&builder_, {6.0, 15.0}, {}, errspec_); +} + +TEST_F(VecOpsReduceTest, AddReduceR2F32Dim0) { + auto sum_reducer = CreateScalarAddComputation(F32, &builder_); + + // clang-format off + auto x = 
builder_.ConstantR2({ + {1.0, 2.0, 3.0}, + {4.0, 5.0, 6.0}}); + // clang-format on + auto add_reduce = + builder_.Reduce(x, builder_.ConstantR0(0.0f), sum_reducer, + /*dimensions_to_reduce=*/{0}); + + ComputeAndCompareR1(&builder_, {5.0, 7.0, 9.0}, {}, errspec_); +} + +TEST_F(VecOpsReduceTest, AddReduceR3F32Dim2) { + auto sum_reducer = CreateScalarAddComputation(F32, &builder_); + auto x = BuildSampleConstantCube(); + auto add_reduce = + builder_.Reduce(x, builder_.ConstantR0(0.0f), sum_reducer, + /*dimensions_to_reduce=*/{2}); + + Array2D expected_array({{6.0f, 15.0f}, {6.0f, 15.0f}, {6.0f, 15.0f}}); + + ComputeAndCompareR2(&builder_, expected_array, {}, errspec_); +} + +TEST_F(VecOpsReduceTest, AddReduceR3F32Dim1) { + auto sum_reducer = CreateScalarAddComputation(F32, &builder_); + auto x = BuildSampleConstantCube(); + auto add_reduce = + builder_.Reduce(x, builder_.ConstantR0(0.0f), sum_reducer, + /*dimensions_to_reduce=*/{1}); + + Array2D expected_array( + {{5.0f, 7.0f, 9.0f}, {5.0f, 7.0f, 9.0f}, {5.0f, 7.0f, 9.0f}}); + + ComputeAndCompareR2(&builder_, expected_array, {}, errspec_); +} + +TEST_F(VecOpsReduceTest, AddReduceR3F32Dim0) { + auto sum_reducer = CreateScalarAddComputation(F32, &builder_); + auto x = BuildSampleConstantCube(); + auto add_reduce = + builder_.Reduce(x, builder_.ConstantR0(0.0f), sum_reducer, + /*dimensions_to_reduce=*/{0}); + + Array2D expected_array({{3.0f, 6.0f, 9.0f}, {12.0f, 15.0f, 18.0f}}); + + ComputeAndCompareR2(&builder_, expected_array, {}, errspec_); +} + +TEST_F(VecOpsReduceTest, AddReduceR3F32Dims1and2) { + auto sum_reducer = CreateScalarAddComputation(F32, &builder_); + auto x = BuildSampleConstantCube(); + auto add_reduce = + builder_.Reduce(x, builder_.ConstantR0(0.0f), sum_reducer, + /*dimensions_to_reduce=*/{1, 2}); + + ComputeAndCompareR1(&builder_, {21.0, 21.0, 21.0}, {}, errspec_); +} + +XLA_TEST_F(VecOpsReduceTest, AddReduceR3F32Dims0and2) { + auto sum_reducer = CreateScalarAddComputation(F32, &builder_); + auto x = BuildSampleConstantCube(); + auto add_reduce = + builder_.Reduce(x, builder_.ConstantR0(0.0f), sum_reducer, + /*dimensions_to_reduce=*/{0, 2}); + + ComputeAndCompareR1(&builder_, {18.0, 45.0}, {}, errspec_); +} + +TEST_F(VecOpsReduceTest, AddReduceR3F32Dims0and1) { + auto sum_reducer = CreateScalarAddComputation(F32, &builder_); + auto x = BuildSampleConstantCube(); + auto add_reduce = + builder_.Reduce(x, builder_.ConstantR0(0.0f), sum_reducer, + /*dimensions_to_reduce=*/{0, 1}); + + ComputeAndCompareR1(&builder_, {15.0, 21.0, 27.0}, {}, errspec_); +} + +TEST_F(VecOpsReduceTest, AddReduceR3F32AllDims) { + auto sum_reducer = CreateScalarAddComputation(F32, &builder_); + auto x = BuildSampleConstantCube(); + auto add_reduce = + builder_.Reduce(x, builder_.ConstantR0(0.0f), sum_reducer, + /*dimensions_to_reduce=*/{0, 1, 2}); + + ComputeAndCompareR0(&builder_, 63.0, {}, errspec_); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc 
b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc new file mode 100644 index 0000000000..d9fc1e1e8f --- /dev/null +++ b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc @@ -0,0 +1,423 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class VecOpsSimpleTest : public ClientLibraryTestBase { + public: + explicit VecOpsSimpleTest(perftools::gputools::Platform* platform = nullptr) + : ClientLibraryTestBase(platform, + /*disabled_pass_names=*/{"algsimp", "inline"}) {} + + ErrorSpec error_spec_{0.0001}; +}; + +TEST_F(VecOpsSimpleTest, ExpTenValues) { + ComputationBuilder builder(client_, TestName()); + auto x = builder.ConstantR1( + {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); + auto exp = builder.Exp(x); + + std::vector expected = {8.1662, 7.4274e-02, 13.4637, 1.8316e-02, + 8.1662, 9.9742, 6.7379e-03, 4.0657e-01, + 9.0718e-02, 4.9530}; + + ComputeAndCompareR1(&builder, expected, {}, error_spec_); +} + +TEST_F(VecOpsSimpleTest, ExpManyValues) { + for (int count : {63, 64, 65, 127, 128, 129, 17 * 4096}) { + ComputationBuilder builder(client_, TestName()); + std::vector exponents; + for (int i = 0; i < count; ++i) { + exponents.push_back(i / static_cast(count)); + } + auto x = builder.ConstantR1(exponents); + auto exp = builder.Exp(x); + + std::vector expected; + for (float exponent : exponents) { + expected.push_back(std::exp(exponent)); + } + + ComputeAndCompareR1(&builder, expected, {}, + ErrorSpec(/*aabs=*/1e-2, /*arel=*/1e-3)); + } +} + +TEST_F(VecOpsSimpleTest, ExpIn4D) { + ComputationBuilder builder(client_, TestName()); + Array4D exponents(2, 2, 2, 2); + + std::vector exponents_vector; + std::vector expected_vector; + for (int i = 0; i < exponents.num_elements(); ++i) { + exponents_vector.push_back(static_cast(i) / + exponents.num_elements()); + expected_vector.push_back(std::exp(exponents_vector.back())); + } + 
exponents.SetValues(exponents_vector); + + Array4D expected(2, 2, 2, 2, expected_vector); + + auto x = builder.ConstantR4FromArray4D(exponents); + auto exp = builder.Exp(x); + + ComputeAndCompareR4(&builder, expected, {}, + ErrorSpec(/*aabs=*/1e-2, /*arel=*/1e-3)); +} + +TEST_F(VecOpsSimpleTest, NegateTenFloatValues) { + ComputationBuilder builder(client_, TestName()); + auto x = builder.ConstantR1( + {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); + builder.Neg(x); + + std::vector expected = {-2.1, 2.6, -2.6, 4.0, -2.1, + -2.3, 5.0, 0.9, 2.4, -1.6}; + ComputeAndCompareR1(&builder, expected, {}, error_spec_); +} + +TEST_F(VecOpsSimpleTest, NegateTenInt32Values) { + ComputationBuilder builder(client_, TestName()); + auto x = builder.ConstantR1({2, -2, 12, -4, 5, 20, -15, 0, -2, 1}); + builder.Neg(x); + + std::vector expected = {-2, 2, -12, 4, -5, -20, 15, 0, 2, -1}; + ComputeAndCompareR1(&builder, expected, {}); +} + +TEST_F(VecOpsSimpleTest, NegateUint32Values) { + ComputationBuilder builder(client_, TestName()); + auto x = builder.ConstantR1( + {0, 1, 42, static_cast(-1), static_cast(-12)}); + builder.Neg(x); + std::vector expected = {0, static_cast(-1), + static_cast(-42), 1, 12}; + ComputeAndCompareR1(&builder, expected, {}); +} + +TEST_F(VecOpsSimpleTest, SquareTenValues) { + ComputationBuilder builder(client_, TestName()); + auto x = builder.ConstantR1( + {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); + builder.SquareF32(x); + + std::vector expected = {4.41, 6.76, 6.76, 16., 4.41, + 5.29, 25., 0.81, 5.76, 2.56}; + ComputeAndCompareR1(&builder, expected, {}, error_spec_); +} + +TEST_F(VecOpsSimpleTest, ReciprocalTenValues) { + ComputationBuilder builder(client_, TestName()); + auto x = builder.ConstantR1( + {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); + builder.ReciprocalF32(x); + + std::vector expected = { + 0.47619048, -0.38461538, 0.38461538, -0.25, 0.47619048, + 0.43478261, -0.2, -1.11111111, -0.41666667, 0.625}; + ComputeAndCompareR1(&builder, expected, {}, error_spec_); +} + +TEST_F(VecOpsSimpleTest, AddTenValuesViaMap) { + ComputationBuilder builder(client_, TestName()); + auto add = CreateScalarAddComputation(F32, &builder); + + auto x = builder.ConstantR1( + {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); + auto y = builder.ConstantR1( + {-0.4, -0.6, -3.0, 0.2, 3.8, -2.2, -1.8, 4.9, 1.4, 0.6}); + auto max = builder.Map({x, y}, add); + + std::vector expected = {1.7, -3.2, -0.4, -3.8, 5.9, + 0.1, -6.8, 4., -1., 2.2}; + ComputeAndCompareR1(&builder, expected, {}, error_spec_); +} + +TEST_F(VecOpsSimpleTest, MaxTenValues) { + ComputationBuilder builder(client_, TestName()); + auto x = builder.ConstantR1( + {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); + auto y = builder.ConstantR1( + {-0.4, -0.6, -3.0, 0.2, 3.8, -2.2, -1.8, 4.9, 1.4, 0.6}); + auto max = builder.Max(x, y); + + std::vector expected = {2.1, -0.6, 2.6, 0.2, 3.8, + 2.3, -1.8, 4.9, 1.4, 1.6}; + ComputeAndCompareR1(&builder, expected, {}); +} + +TEST_F(VecOpsSimpleTest, MaxTenValuesFromParams) { + // Similar to MaxTenValues, except that the inputs come from params rather + // than constants. 
+ ComputationBuilder builder(client_, TestName()); + ComputationDataHandle v1, v2; + std::unique_ptr param0_data = CreateR1Parameter( + {41.0f, 2.0f, 3.0f, 84.0f}, /*parameter_number=*/0, /*name=*/"v1", + /*builder=*/&builder, /*data_handle=*/&v1); + std::unique_ptr param1_data = CreateR1Parameter( + {21.0f, 22.0f, 23.0f, 24.0f}, /*parameter_number=*/1, /*name=*/"v2", + /*builder=*/&builder, /*data_handle=*/&v2); + + auto max = builder.Max(v1, v2); + ComputeAndCompareR1(&builder, {41.0f, 22.0f, 23.0f, 84.0f}, + {param0_data.get(), param1_data.get()}, + error_spec_); +} + +TEST_F(VecOpsSimpleTest, Max15000ValuesFromParams) { + // Similar to MaxTenValuesFromParams, except that the data size passed in and + // out is large. + ComputationBuilder builder(client_, TestName()); + + // Number of floats in the data passed into and out of the computation. + constexpr int datalen = 15 * 1000; + + // The inputs are initialized with a special pattern where in the first third + // of the data v1[i] > v2[i] and elsewhere it's vice versa. + std::vector v1vec; + std::vector v2vec; + std::vector expected_vec; + for (int i = 0; i < datalen; ++i) { + float smaller = i; + float larger = i * 2; + if (i < datalen / 3) { + v1vec.push_back(larger); + v2vec.push_back(smaller); + } else { + v1vec.push_back(smaller); + v2vec.push_back(larger); + } + expected_vec.push_back(larger); + } + + ComputationDataHandle v1, v2; + std::unique_ptr param0_data = + CreateR1Parameter(v1vec, /*parameter_number=*/0, /*name=*/"v1", + /*builder=*/&builder, /*data_handle=*/&v1); + std::unique_ptr param1_data = + CreateR1Parameter(v2vec, /*parameter_number=*/1, /*name=*/"v2", + /*builder=*/&builder, /*data_handle=*/&v2); + + auto max = builder.Max(v1, v2); + ComputeAndCompareR1(&builder, expected_vec, + {param0_data.get(), param1_data.get()}, + error_spec_); +} + +TEST_F(VecOpsSimpleTest, MaxTenValuesWithScalar) { + ComputationBuilder builder(client_, TestName()); + auto x = builder.ConstantR1( + {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); + auto y = builder.ConstantR0(0); + auto max = builder.Max(x, y); + + std::vector expected = {2.1, 0.0, 2.6, 0.0, 2.1, + 2.3, 0.0, 0.0, 0.0, 1.6}; + ComputeAndCompareR1(&builder, expected, {}); +} + +TEST_F(VecOpsSimpleTest, MinTenValues) { + ComputationBuilder builder(client_, TestName()); + auto x = builder.ConstantR1( + {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); + auto y = builder.ConstantR1( + {-0.4, -0.6, -3.0, 0.2, 3.8, -2.2, -1.8, 4.9, 1.4, 0.6}); + auto min = builder.Min(x, y); + + std::vector expected = {-0.4, -2.6, -3.0, -4.0, 2.1, + -2.2, -5.0, -0.9, -2.4, 0.6}; + ComputeAndCompareR1(&builder, expected, {}); +} + +TEST_F(VecOpsSimpleTest, MinMaxTenValues) { + ComputationBuilder builder(client_, TestName()); + auto zero = builder.ConstantR0(0); + auto one = builder.ConstantR0(1); + auto x = builder.ConstantR1( + {2.1, -2.6, 2.6, 0.3, 3.1, 0.9, -5.0, 0.1, -2.4, 0.6}); + auto clamp = builder.Min(builder.Max(x, zero), one); + + std::vector expected = {1.0, 0.0, 1.0, 0.3, 1.0, + 0.9, 0.0, 0.1, 0.0, 0.6}; + ComputeAndCompareR1(&builder, expected, {}); +} + +TEST_F(VecOpsSimpleTest, ClampTenValuesConstant) { + ComputationBuilder builder(client_, TestName()); + auto zero = builder.ConstantR0(0); + auto one = builder.ConstantR0(1); + auto x = builder.ConstantR1( + {2.1, -2.6, 2.6, 0.3, 3.1, 0.9, -5.0, 0.1, -2.4, 0.6}); + auto clamp = builder.Clamp(zero, x, one); + + std::vector expected = {1.0, 0.0, 1.0, 0.3, 1.0, + 0.9, 0.0, 0.1, 0.0, 0.6}; + 
ComputeAndCompareR1(&builder, expected, {}); +} + +TEST_F(VecOpsSimpleTest, ClampTwoValuesConstant) { + ComputationBuilder builder(client_, TestName()); + auto zero = builder.ConstantR1({0.0f, 0.0f}); + auto one = builder.ConstantR1({1.0f, 1.0f}); + auto x = builder.ConstantR1({2.1, -2.6}); + auto clamp = builder.Clamp(zero, x, one); + + std::vector expected = {1.0, 0.0}; + ComputeAndCompareR1(&builder, expected, {}); +} + +TEST_F(VecOpsSimpleTest, ClampTenValuesConstantNonzeroLower) { + ComputationBuilder builder(client_, TestName()); + auto one = builder.ConstantR0(1); + auto two = builder.ConstantR0(2); + auto x = builder.ConstantR1( + {2.1, -2.6, 2.6, 0.3, 3.1, 0.9, -5.0, 0.1, -2.4, 0.6}); + auto clamp = builder.Clamp(one, x, two); + + std::vector expected = {2.0, 1.0, 2.0, 1.0, 2.0, + 1.0, 1.0, 1.0, 1.0, 1.0}; + ComputeAndCompareR1(&builder, expected, {}); +} + +TEST_F(VecOpsSimpleTest, MapTenValues) { + Computation add_half; + { + // add_half(x) = x + 0.5 + ComputationBuilder builder(client_, "add_half"); + auto x_value = + builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x_value"); + auto half = builder.ConstantR0(0.5); + builder.Add(x_value, half); + auto computation_status = builder.Build(); + ASSERT_IS_OK(computation_status.status()); + add_half = computation_status.ConsumeValueOrDie(); + } + + Computation clamp; + { + // clamp(y) = clamp<0,5>(y) + ComputationBuilder builder(client_, "clamp"); + auto y_value = + builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "y_value"); + auto zero = builder.ConstantR0(0.0); + auto clamped = builder.Clamp(zero, y_value, builder.ConstantR0(5)); + auto computation_status = builder.Build(); + ASSERT_IS_OK(computation_status.status()); + clamp = computation_status.ConsumeValueOrDie(); + } + + Computation mult_relu_add; + { + // mult_relu_add(z) = clamp(add_half(2 * max(z, 0))) + ComputationBuilder builder(client_, "mult_relu_add"); + auto z_value = + builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "z_value"); + auto zero = builder.ConstantR0(0.0); + auto two = builder.ConstantR0(2.0); + auto max = builder.Max(z_value, zero); + auto mult = builder.Mul(two, max); + auto inner = builder.Map({mult}, add_half); + builder.Map({inner}, clamp); + auto computation_status = builder.Build(); + ASSERT_IS_OK(computation_status.status()); + mult_relu_add = computation_status.ConsumeValueOrDie(); + } + + ComputationBuilder builder(client_, "map10"); + { + auto x = builder.ConstantR1( + {2.1, -21.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); + auto activations = builder.Map({x}, mult_relu_add); + } + + std::vector expected = {4.7, 0.5, 5.0, 0.5, 4.7, + 5.0, 0.5, 0.5, 0.5, 3.7}; + ComputeAndCompareR1(&builder, expected, {}); +} + +XLA_TEST_F(VecOpsSimpleTest, RemainderTenValuesS32) { + ComputationBuilder builder(client_, TestName()); + auto x = builder.ConstantR1({-5, -4, -3, -2, -1, 0, 1, 2, 3, 4}); + auto y = builder.ConstantR0(3); + builder.Rem(x, y); + + std::vector expected = {-2, -1, 0, -2, -1, 0, 1, 2, 0, 1}; + ComputeAndCompareR1(&builder, expected, {}); +} + +XLA_TEST_F(VecOpsSimpleTest, VectorPredicateEqual) { + ComputationBuilder builder(client_, TestName()); + auto x = builder.ConstantR1({false, true}); + auto y = builder.ConstantR1({true, false}); + builder.Eq(x, y); + + std::array expected = {{false, false}}; + ComputeAndCompareR1(&builder, expected, {}); +} + +XLA_TEST_F(VecOpsSimpleTest, VectorPredicateNotEqual) { + ComputationBuilder builder(client_, TestName()); + auto x = builder.ConstantR1({false, true}); + auto y = 
builder.ConstantR1({true, false}); + builder.Ne(x, y); + + std::array expected = {{true, true}}; + ComputeAndCompareR1(&builder, expected, {}); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc new file mode 100644 index 0000000000..7820bc363d --- /dev/null +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -0,0 +1,395 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/platform_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/platform/types.h" + +namespace se = ::perftools::gputools; + +namespace xla { +namespace { + +class WhileTest : public ClientLibraryTestBase {}; + +// Tests a while node when the result type T is S32. +// +// int32 result = 0; +// while (result < 5) { +// result = result + 1; +// } +TEST_F(WhileTest, WhileWithScalarResult) { + auto result_shape = ShapeUtil::MakeShape(S32, {}); + + // Create a computation for the condition: repeat for 5 iterations. + Computation condition; + { + ComputationBuilder builder(client_, "condition"); + auto prev = builder.Parameter(0, result_shape, "prev"); + builder.Gt(builder.ConstantR0(5), prev); + condition = builder.Build().ConsumeValueOrDie(); + } + + // Create a computation for the body: add 1 to the result variable. 
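+  // Both the condition above and the body below follow the convention used
+  // throughout these While tests: the loop state arrives as parameter 0,
+  // e.g.
+  //
+  //   auto prev = builder.Parameter(0, result_shape, "prev");
+  //
+  // and the body returns a value of that same shape.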
+ Computation body; + { + ComputationBuilder builder(client_, "body"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto input = builder.ConstantR0(1); + auto result = builder.Add(input, prev); + body = builder.Build().ConsumeValueOrDie(); + } + + // Create a While node with computations for the condition and the body. + ComputationBuilder builder(client_, TestName()); + auto init = builder.ConstantR0(0); + auto result = builder.While(condition, body, init); + auto shape = builder.GetShape(result).ConsumeValueOrDie(); + + ComputeAndCompareR0(&builder, 5, {}); +} + +// Tests a while node when the result type T is a vector. +// +// All constants are chosen to produce exact results. +// vector result(0); +// while (result.sum() < 15.5f) { +// result = result + vector(0); +// } +// TODO(b/29185393): does not terminate on CPU. +TEST_F(WhileTest, DISABLED_WhileWithEmptyVectorResult) { + Shape result_shape = ShapeUtil::MakeShape(F32, {0}); + + // Create a computation for the reduction. + Computation add; + { + ComputationBuilder builder(client_, "add"); + auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); + builder.Add(x, y); + add = builder.Build().ConsumeValueOrDie(); + } + + // Create a computation for the condition. + // Repeat until the sum of the result vector is less than 15.5f. + Computation condition; + { + ComputationBuilder builder(client_, "condition"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto sum = builder.Reduce(prev, builder.ConstantR0(0.0f), add, + /*dimensions_to_reduce=*/{0}); + auto test = builder.Gt(builder.ConstantR0(15.5f), sum); + condition = builder.Build().ConsumeValueOrDie(); + } + + // Create a computation for the body. + // Add a constant vector of 1.f to the result vector. + Computation body; + { + ComputationBuilder builder(client_, "body"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto input = builder.ConstantR1({}); + auto result = builder.Add(input, prev); + body = builder.Build().ConsumeValueOrDie(); + } + + // Create a While node with computations for the condition and the body. + ComputationBuilder builder(client_, "while"); + auto init = builder.ConstantR1({}); + auto result = builder.While(condition, body, init); + VLOG(2) << "while = " << ShapeUtil::HumanString( + *builder.GetShape(result).ConsumeValueOrDie()); + + ComputeAndCompareR1(&builder, {}, {}, ErrorSpec(0.0001)); +} + +// Tests a while node when the result type T is a vector. +// +// All constants are chosen to produce exact results. +// vector result(8, 0.0f); +// while (result.sum() < 15.5f) { +// result = result + vector(8, 0.125f); +// } +TEST_F(WhileTest, WhileWithVectorResult) { + Shape result_shape = ShapeUtil::MakeShape(F32, {8}); + + // Create a computation for the reduction. + Computation add; + { + ComputationBuilder builder(client_, "add"); + auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); + builder.Add(x, y); + add = builder.Build().ConsumeValueOrDie(); + } + + // Create a computation for the condition. + // Repeat until the sum of the result vector is less than 5.5f. 
+ Computation condition; + { + ComputationBuilder builder(client_, "condition"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto sum = builder.Reduce(prev, builder.ConstantR0(0.0f), add, + /*dimensions_to_reduce=*/{0}); + auto test = builder.Gt(builder.ConstantR0(15.5f), sum); + condition = builder.Build().ConsumeValueOrDie(); + } + + // Create a computation for the body. + // Add a constant vector of 1.f to the result vector. + Computation body; + { + ComputationBuilder builder(client_, "body"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto input = builder.ConstantR1(8, 0.125f); + auto result = builder.Add(input, prev); + body = builder.Build().ConsumeValueOrDie(); + } + + // Create a While node with computations for the condition and the body. + ComputationBuilder builder(client_, "while"); + auto init = builder.ConstantR1(8, 0.f); + auto result = builder.While(condition, body, init); + VLOG(2) << "while = " << ShapeUtil::HumanString( + *builder.GetShape(result).ConsumeValueOrDie()); + + // Individual elements with increase by 1/8 each time through the loop, so + // the sum will increase by 1.0. It will first be >15.5 when the elements + // have all reached 2.0. + std::vector expected = {2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f}; + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); +} + +// Tests a while node when the result type T is a Tuple. +// +// tuple> result(0, vector(10, 0.0f)); +// while (get<0>(result) < 5) { +// get<0>(result) = get<0>(result) + 1; +// get<1>(result) = get<1>(result) + vector(10, 1.0f); +// } +TEST_F(WhileTest, WhileWithTupleResult) { + std::vector shape_elements = {ShapeUtil::MakeShape(S32, {}), + ShapeUtil::MakeShape(F32, {10})}; + Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); + + // Create a computation for the condition. + // Repeat for 5 iterations. + Computation condition; + { + ComputationBuilder builder(client_, "condition"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + builder.Gt(builder.ConstantR0(5), iteration); + condition = builder.Build().ConsumeValueOrDie(); + } + + // Create a computation for the body. + // Add 1 to the iteration variable and add a constant vector of 1.0f to + // the weight variable, both of which are tuple elements. + Computation body; + { + ComputationBuilder builder(client_, "body"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + auto weights = builder.GetTupleElement(prev, 1); + auto input = builder.ConstantR1(10, 1.f); + auto new_weights = builder.Add(weights, input); + auto result = builder.Tuple( + {builder.Add(iteration, builder.ConstantR0(1)), new_weights}); + body = builder.Build().ConsumeValueOrDie(); + } + + // Create a While node with computations for the condition and the body. 
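+  // As a sketch of the shape contract assumed by ComputationBuilder::While in
+  // these tests:
+  //
+  //   auto result = builder.While(condition, body, init);
+  //
+  // `init`, the parameter of `condition`, the parameter of `body`, and the
+  // value returned by `body` all share one shape (here the (S32, F32[10])
+  // tuple), and `condition` returns a scalar predicate.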
+ ComputationBuilder builder(client_, "while"); + auto init = builder.Tuple( + {builder.ConstantR0(0), builder.ConstantR1(10, 0.f)}); + auto result = builder.While(condition, body, init); + VLOG(2) << "while = " << ShapeUtil::HumanString( + *builder.GetShape(result).ConsumeValueOrDie()); + + auto expected_counter = LiteralUtil::CreateR0(5); + auto expected_data = LiteralUtil::CreateR1( + {5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f}); + auto expected = + LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()}); + VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); + ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); +} + +// Tests a while node when the result type T is a vector of S32. +// +// int32 result = (0, 0, 0, 0, 0, 0); +// while (result[0] < count) { +// result += (1, U[0, 100], U[0, 100], U[0, 100], U[0, 100], U[0, 100]); +// } +// +// This test misuses a vector to represent a pair: +// ((iteration, (random vector))). +// +// Note: this test currently only tests generating random values within a loop. +// Per backend the values generated can be different as the different backends +// use different random number generators. +// TODO(b/32240857): Extend test to verify outputs. +TEST_F(WhileTest, WhileWithPrngScalarResult) { + auto v6s32 = ShapeUtil::MakeShape(S32, {6}); + + // Create a computation for the condition: repeat for count iterations. + auto build_condition = [this, v6s32](int count) { + ComputationBuilder builder(client_, TestName()); + auto prev = builder.Reshape( + builder.Slice(builder.Parameter(0, v6s32, "prev"), {0}, {1}), {0}, {}); + builder.Gt(builder.ConstantR0(count), prev); + return builder.Build().ConsumeValueOrDie(); + }; + + // Create a computation for the body: add 1 to the result variable. + Computation body; + { + ComputationBuilder builder(client_, "body"); + auto prev = builder.Parameter(0, v6s32, "prev"); + auto inc = builder.ConcatInDim( + {builder.ConstantR1({1}), + builder.RngUniform(builder.ConstantR0(0), + builder.ConstantR0(100), + ShapeUtil::MakeShape(S32, {5}))}, + 0); + auto result = builder.Add(inc, prev); + body = builder.Build().ConsumeValueOrDie(); + } + + // Create a While node with computations for the condition and the body. + auto while_loop = [this, &body, build_condition](int count) { + ComputationBuilder builder(client_, TestName()); + auto init = builder.ConstantR1({0, 0, 0, 0, 0, 0}); + auto result = builder.While(build_condition(count), body, init); + auto shape = builder.GetShape(result).ConsumeValueOrDie(); + return builder.Build(); + }; + + for (int i = 1; i < 4; ++i) { + TF_ASSIGN_OR_ASSERT_OK(auto computation, while_loop(i)); + TF_ASSIGN_OR_ASSERT_OK(auto result, + client_->ExecuteAndTransfer(computation, {}, nullptr, + nullptr, /*seed=*/65)); + } +} + +void BM_WhileLoop(int num_iters) { + // Benchmark a simple kernel to measure while loop overheads. + tensorflow::testing::StopTiming(); + + se::Platform* platform = PlatformUtil::GetDefaultPlatform().ValueOrDie(); + auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie(); + StreamExecutorMemoryAllocator allocator(platform, executors); + LocalClient* client = + ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie(); + + Shape loop_state_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {10})}); + + // Create while condition computation with 'loop_limit'. 
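+  // As in WhileWithTupleResult above, the loop state is a (s32 counter,
+  // f32[10] weights) tuple: the condition keeps iterating while
+  // counter < loop_limit and the body increments the counter and adds a
+  // vector of ones to the weights, so every execution timed below runs a
+  // 100-trip while loop end to end.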
+ const int32 loop_limit = 100; + Computation condition; + { + ComputationBuilder builder(client, "condition"); + auto prev = builder.Parameter(0, loop_state_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + builder.Lt(iteration, builder.ConstantR0(loop_limit)); + condition = builder.Build().ConsumeValueOrDie(); + } + + // Create while body computation with unit loop increment. + Computation body; + { + ComputationBuilder builder(client, "body"); + auto prev = builder.Parameter(0, loop_state_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + auto weights = builder.GetTupleElement(prev, 1); + auto one = builder.ConstantR0(1); + auto next_iteration = builder.Add(iteration, one); + auto one_vec = builder.ConstantR1(10, 1.f); + auto new_weights = builder.Add(weights, one_vec); + auto result = builder.Tuple({next_iteration, new_weights}); + body = builder.Build().ConsumeValueOrDie(); + } + + // Create a While instruction. + ComputationBuilder builder(client, "while"); + auto init = builder.Tuple( + {builder.ConstantR0(0), builder.ConstantR1(10, 0.f)}); + builder.While(condition, body, init); + auto computation = builder.Build().ConsumeValueOrDie(); + + // Run some warm-up executions. + LocalExecuteOptions options; + options.set_allocator(&allocator); + const int kWarmups = 2; + for (int i = 0; i < kWarmups; ++i) { + auto result = client->ExecuteLocally(computation, {}, options); + ASSERT_TRUE(result.ok()); + } + + // Run benchmark. + tensorflow::testing::StartTiming(); + for (int i = 0; i < num_iters; ++i) { + auto result = client->ExecuteLocally(computation, {}, options); + ASSERT_TRUE(result.ok()); + } +} + +// TODO(b/32470510): Benchmark fails on parallel CPU backend. +#ifndef XLA_TEST_BACKEND_CPU_PARALLEL +BENCHMARK(BM_WhileLoop); +#endif + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + tensorflow::testing::RunBenchmarks(); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/text_literal_reader.cc b/tensorflow/compiler/xla/text_literal_reader.cc new file mode 100644 index 0000000000..7876272467 --- /dev/null +++ b/tensorflow/compiler/xla/text_literal_reader.cc @@ -0,0 +1,155 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/text_literal_reader.h" + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/io/buffered_inputstream.h" +#include "tensorflow/core/lib/io/random_inputstream.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +StatusOr> TextLiteralReader::ReadPath( + tensorflow::StringPiece path) { + CHECK(!path.ends_with(".gz")) + << "TextLiteralReader no longer supports reading .gz files"; + std::unique_ptr file; + Status s = + tensorflow::Env::Default()->NewRandomAccessFile(path.ToString(), &file); + if (!s.ok()) { + return s; + } + + TextLiteralReader reader(file.release()); + return reader.ReadAllLines(); +} + +TextLiteralReader::TextLiteralReader(tensorflow::RandomAccessFile* file) + : file_(file) {} + +namespace { +// This is an optimized version of tensorflow::str_util::Split which uses +// StringPiece for the delimited strings and uses an out parameter for the +// result to avoid vector creation/destruction. +void SplitByDelimToStringPieces(tensorflow::StringPiece text, char delim, + std::vector* result) { + result->clear(); + + if (text.empty()) { + return; + } + + // The following loop is a little strange: its bound is text.size() + 1 + // instead of the more typical text.size(). + // The final iteration of the loop (when i is equal to text.size()) handles + // the trailing token. 
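+  // For example, with delim == ',':
+  //   "1, 2" -> {"1", " 2"}
+  //   "a,"   -> {"a", ""}   (the i == text.size() pass emits the empty
+  //                          trailing token)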
+ size_t token_start = 0; + for (size_t i = 0; i < text.size() + 1; i++) { + if (i == text.size() || text[i] == delim) { + tensorflow::StringPiece token(text.data() + token_start, i - token_start); + result->push_back(token); + token_start = i + 1; + } + } +} +} // namespace + +StatusOr> TextLiteralReader::ReadAllLines() { + tensorflow::io::RandomAccessInputStream stream(file_.get()); + tensorflow::io::BufferedInputStream buf(&stream, 65536); + string shape_string; + Status s = buf.ReadLine(&shape_string); + if (!s.ok()) { + return s; + } + + tensorflow::StringPiece sp(shape_string); + if (tensorflow::str_util::RemoveWhitespaceContext(&sp) > 0) { + string tmp = sp.ToString(); + shape_string = tmp; + } + TF_ASSIGN_OR_RETURN(Shape shape, ShapeUtil::ParseShapeString(shape_string)); + if (shape.element_type() != F32) { + return Unimplemented( + "unsupported element type for text literal reading: %s", + ShapeUtil::HumanString(shape).c_str()); + } + + auto result = MakeUnique(); + const float fill = std::numeric_limits::quiet_NaN(); + LiteralUtil::PopulateWithValue(fill, AsInt64Slice(shape.dimensions()), + result.get()); + std::vector pieces; + std::vector coordinates; + std::vector coordinate_values; + string line; + while (buf.ReadLine(&line).ok()) { + SplitByDelimToStringPieces(line, ':', &pieces); + tensorflow::StringPiece coordinates_string = pieces[0]; + tensorflow::StringPiece value_string = pieces[1]; + tensorflow::str_util::RemoveWhitespaceContext(&coordinates_string); + tensorflow::str_util::RemoveWhitespaceContext(&value_string); + if (!coordinates_string.Consume("(")) { + return InvalidArgument( + "expected '(' at the beginning of coordinates: \"%s\"", line.c_str()); + } + if (!tensorflow::str_util::ConsumeSuffix(&coordinates_string, ")")) { + return InvalidArgument("expected ')' at the end of coordinates: \"%s\"", + line.c_str()); + } + float value; + if (!tensorflow::strings::safe_strtof(value_string.ToString().c_str(), + &value)) { + return InvalidArgument("could not parse value as float: \"%s\"", + value_string.ToString().c_str()); + } + SplitByDelimToStringPieces(coordinates_string, ',', &coordinates); + coordinate_values.clear(); + for (tensorflow::StringPiece piece : coordinates) { + int64 coordinate_value; + if (!tensorflow::strings::safe_strto64(piece, &coordinate_value)) { + return InvalidArgument( + "could not parse coordinate member as int64: \"%s\"", + piece.ToString().c_str()); + } + coordinate_values.push_back(coordinate_value); + } + if (coordinate_values.size() != shape.dimensions_size()) { + return InvalidArgument( + "line did not have expected number of coordinates; want %d got %zu: " + "\"%s\"", + shape.dimensions_size(), coordinate_values.size(), line.c_str()); + } + LiteralUtil::Set(result.get(), coordinate_values, value); + } + return std::move(result); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/text_literal_reader.h b/tensorflow/compiler/xla/text_literal_reader.h new file mode 100644 index 0000000000..3cfbb2c7fb --- /dev/null +++ b/tensorflow/compiler/xla/text_literal_reader.h @@ -0,0 +1,62 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_READER_H_ +#define TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_READER_H_ + +#include + +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/macros.h" + +namespace xla { + +// Reads a textual literal from a file path. The format of the file must be: +// +// f32[1,2,3,4] +// (0, 0, 0, 0): 1.234 +// (0, 0, 0, 1): 0xf00p-2 +// ... +// +// Note that for floating values the hex output (as in the second value above) +// will more precisely convey the exact values. +class TextLiteralReader { + public: + // See class comment -- reads a file in its entirety (there must be only one + // literal in the text file path provided). + static StatusOr> ReadPath( + tensorflow::StringPiece path); + + private: + // Ownership of file is transferred. + explicit TextLiteralReader(tensorflow::RandomAccessFile* file); + + // Parses a shape string on the first line, followed by lines of values to the + // end of the file. + StatusOr> ReadAllLines(); + + // Owns the file being read + std::unique_ptr file_; + + TF_DISALLOW_COPY_AND_ASSIGN(TextLiteralReader); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_READER_H_ diff --git a/tensorflow/compiler/xla/text_literal_reader_test.cc b/tensorflow/compiler/xla/text_literal_reader_test.cc new file mode 100644 index 0000000000..94d0f2646b --- /dev/null +++ b/tensorflow/compiler/xla/text_literal_reader_test.cc @@ -0,0 +1,58 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/text_literal_reader.h" + +#include + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +TEST(TextLiteralReaderTest, ReadsR3File) { + string contents = R"(f32[1,2,3] +(0,0,0): 42.5 +(0,0,1): 43.5 +(0,0,2): 44.5 +(0,1,0): 45.5 +(0,1,1): 46.5 +(0,1,2): 47.5 +)"; + + string fname = tensorflow::testing::TmpDir() + "/ReadsR3File.data.txt"; + EXPECT_TRUE( + tensorflow::WriteStringToFile(tensorflow::Env::Default(), fname, contents) + .ok()); + + std::unique_ptr literal = + TextLiteralReader::ReadPath(fname).ConsumeValueOrDie(); + EXPECT_TRUE( + ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {1, 2, 3}), literal->shape())); + EXPECT_EQ(42.5, LiteralUtil::Get(*literal, {0, 0, 0})); + EXPECT_EQ(43.5, LiteralUtil::Get(*literal, {0, 0, 1})); + EXPECT_EQ(44.5, LiteralUtil::Get(*literal, {0, 0, 2})); + EXPECT_EQ(45.5, LiteralUtil::Get(*literal, {0, 1, 0})); + EXPECT_EQ(46.5, LiteralUtil::Get(*literal, {0, 1, 1})); + EXPECT_EQ(47.5, LiteralUtil::Get(*literal, {0, 1, 2})); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/text_literal_writer.cc b/tensorflow/compiler/xla/text_literal_writer.cc new file mode 100644 index 0000000000..a5097e41cb --- /dev/null +++ b/tensorflow/compiler/xla/text_literal_writer.cc @@ -0,0 +1,64 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/text_literal_writer.h" + +#include + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +/* static */ tensorflow::Status TextLiteralWriter::WriteToPath( + const Literal& literal, tensorflow::StringPiece path) { + std::unique_ptr f; + auto s = tensorflow::Env::Default()->NewWritableFile(path.ToString(), &f); + if (!s.ok()) { + return s; + } + + s = f->Append(ShapeUtil::HumanString(literal.shape()) + "\n"); + if (!s.ok()) { + return s; + } + + tensorflow::Status status; + tensorflow::WritableFile* f_ptr = f.get(); + LiteralUtil::EachCellAsString( + literal, [f_ptr, &status](tensorflow::gtl::ArraySlice indices, + const string& value) { + if (!status.ok()) { + return; + } + string coordinates = tensorflow::strings::StrCat( + "(", tensorflow::str_util::Join(indices, ", "), ")"); + + status = f_ptr->Append( + tensorflow::strings::StrCat(coordinates, ": ", value, "\n")); + }); + auto ignored = f->Close(); + return status; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/text_literal_writer.h b/tensorflow/compiler/xla/text_literal_writer.h new file mode 100644 index 0000000000..545bd22da9 --- /dev/null +++ b/tensorflow/compiler/xla/text_literal_writer.h @@ -0,0 +1,48 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_WRITER_H_ +#define TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_WRITER_H_ + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/macros.h" + +namespace xla { + +// Writes a literal to textual form at a file path. +// +// The format is roughly: +// +// f32[1,2,3,4] +// (0, 0, 0, 0): 1.234 +// (0, 0, 0, 1): 0xf00p-2 +// ... +// +// This should be readable by xla::TextLiteralReader. 
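+//
+// Example usage (the output path here is purely illustrative):
+//
+//   auto literal = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+//   TF_CHECK_OK(TextLiteralWriter::WriteToPath(*literal, "/tmp/rank2.txt"));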
+class TextLiteralWriter { + public: + static tensorflow::Status WriteToPath(const Literal& literal, + tensorflow::StringPiece path); + + private: + TF_DISALLOW_COPY_AND_ASSIGN(TextLiteralWriter); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_WRITER_H_ diff --git a/tensorflow/compiler/xla/text_literal_writer_test.cc b/tensorflow/compiler/xla/text_literal_writer_test.cc new file mode 100644 index 0000000000..9dce4d13bb --- /dev/null +++ b/tensorflow/compiler/xla/text_literal_writer_test.cc @@ -0,0 +1,52 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/text_literal_writer.h" + +#include +#include + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +TEST(TextLiteralWriterTest, WritesFloatLiteral) { + auto literal = LiteralUtil::CreateR2({ + {3.14, 2.17}, {1.23, 4.56}, + }); + string path = + tensorflow::io::JoinPath(tensorflow::testing::TmpDir(), "/whatever"); + ASSERT_IS_OK(TextLiteralWriter::WriteToPath(*literal, path)); + string contents; + TF_CHECK_OK(tensorflow::ReadFileToString(tensorflow::Env::Default(), path, + &contents)); + const string expected = R"(f32[2,2] +(0, 0): 3.14 +(0, 1): 2.17 +(1, 0): 1.23 +(1, 1): 4.56 +)"; + EXPECT_EQ(expected, contents); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD new file mode 100644 index 0000000000..46eab7f02b --- /dev/null +++ b/tensorflow/compiler/xla/tools/BUILD @@ -0,0 +1,191 @@ +# Tools and utilities that aid in XLA development and usage. + +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow/compiler/xla:internal"]) + +# Filegroup used to collect source files for dependency checking. 
+filegroup( + name = "c_srcs", + data = glob([ + "**/*.cc", + "**/*.h", + ]), + visibility = ["//tensorflow/compiler/xla:internal"], +) + +cc_binary( + name = "hex_floats_to_packed_literal", + srcs = ["hex_floats_to_packed_literal.cc"], + deps = [ + "//tensorflow/compiler/xla:types", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + +cc_library( + name = "dumped_computation_to_graphviz_library", + srcs = ["dumped_computation_to_graphviz.cc"], + deps = [ + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:service_flags", + "//tensorflow/compiler/xla/service", + "//tensorflow/compiler/xla/service:session_proto", + "//tensorflow/core:lib", + ], +) + +cc_binary( + name = "dumped_computation_to_graphviz", + deps = [ + ":dumped_computation_to_graphviz_library", + ], +) + +cc_binary( + name = "show_signature", + srcs = ["show_signature.cc"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/service:session_proto", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "replay_computation_library", + srcs = ["replay_computation.cc"], + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/lib:testing", + "//tensorflow/compiler/xla/service:session_proto", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], + alwayslink = True, +) + +cc_binary( + name = "replay_computation_cpu", + deps = [ + ":replay_computation_library", + "//tensorflow/compiler/xla/service:cpu_plugin", + ], +) + +cc_binary( + name = "replay_computation_gpu", + deps = [ + ":replay_computation_library", + "//tensorflow/compiler/xla/service:gpu_plugin", + ], +) + +cc_binary( + name = "show_literal", + srcs = ["show_literal.cc"], + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + ], +) + +cc_binary( + name = "convert_computation", + srcs = ["convert_computation.cc"], + deps = [ + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/service:session_proto", + "//tensorflow/core:lib", + ], +) + +cc_binary( + name = "show_text_literal", + srcs = ["show_text_literal.cc"], + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:text_literal_reader", + "//tensorflow/compiler/xla:types", + 
"//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + ], +) + +cc_binary( + name = "dumped_computation_to_text", + srcs = ["dumped_computation_to_text.cc"], + deps = [ + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/service", + "//tensorflow/compiler/xla/service:session_proto", + "//tensorflow/core:lib", + ], +) + +cc_binary( + name = "dumped_computation_to_operation_list", + srcs = ["dumped_computation_to_operation_list.cc"], + deps = [ + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/service", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:session_proto", + "//tensorflow/core:lib", + ], +) + +# ----------------------------------------------------------------------------- + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/compiler/xla/tools/convert_computation.cc b/tensorflow/compiler/xla/tools/convert_computation.cc new file mode 100644 index 0000000000..fe03a6e7bd --- /dev/null +++ b/tensorflow/compiler/xla/tools/convert_computation.cc @@ -0,0 +1,60 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Usage: convert_computation serialized_computation_proto +// +// bin2txt spits out the result to stdout. txt2bin modifies the file in place. 
+ +#include +#include +#include + +#include "tensorflow/compiler/xla/service/session.pb.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace xla { +namespace tools { + +void RealMain(const string& mode, const string& path) { + SessionModule module; + tensorflow::Env* env = tensorflow::Env::Default(); + if (mode == "txt2bin") { + TF_CHECK_OK(tensorflow::ReadTextProto(env, path, &module)); + TF_CHECK_OK(tensorflow::WriteBinaryProto(env, path, module)); + } else if (mode == "bin2txt") { + TF_CHECK_OK(tensorflow::ReadBinaryProto(env, path, &module)); + string out; + tensorflow::protobuf::TextFormat::PrintToString(module, &out); + fprintf(stdout, "%s", out.c_str()); + } else { + LOG(QFATAL) << "unknown mode for computation conversion: " << mode; + } +} + +} // namespace tools +} // namespace xla + +int main(int argc, char** argv) { + tensorflow::port::InitMain(argv[0], &argc, &argv); + + QCHECK_EQ(argc, 3) << "usage: " << argv[0] << " "; + xla::tools::RealMain(argv[1], argv[2]); + return 0; +} diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc new file mode 100644 index 0000000000..10efa9f3e8 --- /dev/null +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc @@ -0,0 +1,76 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Usage: dumped_computation_to_graphviz some_binary_snapshot_proto* +// +// Dumps a graphviz URL for a snapshot computation to the command line. +// +// some_binary_snapshot_proto is obtained by serializing the SessionModule from +// ServiceInterface::SnapshotComputation to disk. +// +// The GraphViz URL is placed into the log stderr, whereas computation +// statistics are printed on stdout (implementation note: getting computation +// statistics is how we trigger compilation to split out a GraphViz URL). 
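+//
+// Example invocation (the snapshot path is illustrative):
+//   dumped_computation_to_graphviz /tmp/some_computation.pb
+// writes a ">>> <file> :: <computation stats>" line to stdout for the file and
+// logs the corresponding GraphViz URL.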
+ +#include +#include +#include + +#include "tensorflow/compiler/xla/client/client.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/legacy_flags/service_flags.h" +#include "tensorflow/compiler/xla/service/service.h" +#include "tensorflow/compiler/xla/service/session.pb.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { +namespace tools { + +void RealMain(tensorflow::gtl::ArraySlice args) { + Client* client = ClientLibrary::LocalClientOrDie(); + for (char* arg : args) { + SessionModule module; + TF_CHECK_OK( + tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module)); + Computation computation = client->LoadSnapshot(module).ConsumeValueOrDie(); + ComputationStats stats = + client->GetComputationStats(computation).ConsumeValueOrDie(); + fprintf(stdout, ">>> %s :: %s\n", arg, stats.DebugString().c_str()); + } +} + +} // namespace tools +} // namespace xla + +int main(int argc, char** argv) { + tensorflow::port::InitMain(argv[0], &argc, &argv); + + xla::legacy_flags::ServiceFlags* flags = xla::legacy_flags::GetServiceFlags(); + flags->xla_generate_hlo_graph = ".*"; + flags->xla_hlo_graph_layout = true; + + tensorflow::gtl::ArraySlice args(argv, argc); + args.pop_front(); // Pop off the binary name, argv[0] + xla::tools::RealMain(args); + return 0; +} diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc new file mode 100644 index 0000000000..4c242abc9b --- /dev/null +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc @@ -0,0 +1,111 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Dumps out the operations that are present in a serialized computation. 
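+//
+// Each HLO instruction of the compiled module is printed as a line of the
+// form
+//   opcode :: (operand shapes) -> result shape :: snapshot path
+// for example (with an illustrative path):
+//   add :: (f32[10], f32[10]) -> f32[10] :: /tmp/some_computation.pb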
+ +#include +#include +#include + +#include "tensorflow/compiler/xla/client/client.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/service.h" +#include "tensorflow/compiler/xla/service/session.pb.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { +namespace tools { + +class OperationDumper : public DfsHloVisitorWithDefault { + public: + explicit OperationDumper(const string& path) : path_(path) {} + + Status DefaultAction(HloInstruction* hlo) override { + string params = tensorflow::str_util::Join( + hlo->operands(), ", ", [](string* out, const HloInstruction* operand) { + tensorflow::strings::StrAppend( + out, ShapeUtil::HumanString(operand->shape())); + }); + // Spit `op_name(params...) -> result_type :: path` to stdout. + std::cout << tensorflow::strings::Printf( + "%s :: (%s) -> %s :: %s\n", HloOpcodeString(hlo->opcode()).c_str(), + params.c_str(), ShapeUtil::HumanString(hlo->shape()).c_str(), + path_.c_str()); + return Status::OK(); + } + + private: + string path_; +}; + +void RealMain(tensorflow::gtl::ArraySlice args) { + LocalClient* client = ClientLibrary::LocalClientOrDie(); + LocalService* local_service = + ClientLibrary::GetXlaService(client->platform()); + for (char* arg : args) { + SessionModule session_module; + TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, + &session_module)); + auto computation_status = client->LoadSnapshot(session_module); + if (!computation_status.ok()) { + fprintf(stderr, "could not load snapshot for %s: %s\n", arg, + computation_status.status().ToString().c_str()); + continue; + } + Computation computation = computation_status.ConsumeValueOrDie(); + + std::unique_ptr program_shape = + client->GetComputationShape(computation).ConsumeValueOrDie(); + + std::vector layouts; + for (int i = 0; i < program_shape->parameters_size(); ++i) { + layouts.push_back(&program_shape->parameters(i)); + } + StatusOr> executable = + local_service->CompileExecutable( + computation.handle(), layouts, &program_shape->result(), + /*device_ordinal=*/0, /*has_hybrid_result=*/true); + + const HloModule& module = executable.ValueOrDie()->module(); + + OperationDumper dumper(arg); + for (auto& computation : module.computations()) { + TF_CHECK_OK(computation->root_instruction()->Accept(&dumper)); + } + } +} + +} // namespace tools +} // namespace xla + +int main(int argc, char** argv) { + tensorflow::port::InitMain(argv[0], &argc, &argv); + + tensorflow::gtl::ArraySlice args(argv, argc); + args.pop_front(); // Pop off the binary name, argv[0] + xla::tools::RealMain(args); + return 0; +} diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc new file mode 100644 index 0000000000..8b96e13489 --- /dev/null +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc @@ -0,0 +1,83 @@ +/* Copyright 2017 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "tensorflow/compiler/xla/client/client.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/service/service.h" +#include "tensorflow/compiler/xla/service/session.pb.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { +namespace tools { + +void RealMain(tensorflow::gtl::ArraySlice args) { + LocalClient* client = ClientLibrary::LocalClientOrDie(); + LocalService* local_service = + ClientLibrary::GetXlaService(client->platform()); + for (char* arg : args) { + SessionModule session_module; + TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, + &session_module)); + auto computation_status = client->LoadSnapshot(session_module); + if (!computation_status.ok()) { + fprintf(stderr, "could not load snapshot for %s: %s\n", arg, + computation_status.status().ToString().c_str()); + continue; + } + Computation computation = computation_status.ConsumeValueOrDie(); + + std::unique_ptr program_shape = + client->GetComputationShape(computation).ConsumeValueOrDie(); + + std::vector layouts; + for (int i = 0; i < program_shape->parameters_size(); ++i) { + layouts.push_back(&program_shape->parameters(i)); + } + StatusOr> executable = + local_service->CompileExecutable( + computation.handle(), layouts, &program_shape->result(), + /*device_ordinal=*/0, /*has_hybrid_result=*/true); + + const HloModule& module = executable.ValueOrDie()->module(); + + fprintf(stdout, "HLO for %s backend:\n%s\n", + local_service->backend().platform()->Name().c_str(), + module.ToString().c_str()); + } +} + +} // namespace tools +} // namespace xla + +int main(int argc, char** argv) { + tensorflow::port::InitMain(argv[0], &argc, &argv); + + tensorflow::gtl::ArraySlice args(argv, argc); + args.pop_front(); // Pop off the binary name, argv[0] + xla::tools::RealMain(args); + return 0; +} diff --git a/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc b/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc new file mode 100644 index 0000000000..eb7bff053b --- /dev/null +++ b/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc @@ -0,0 +1,76 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/casts.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/io/buffered_inputstream.h" +#include "tensorflow/core/lib/io/random_inputstream.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/command_line_flags.h" + +using xla::string; + +int main(int argc, char** argv) { + // Flags + string input_file = ""; + string output_file = ""; + const std::vector flag_list = { + tensorflow::Flag("input_file", &input_file, "file to convert"), + tensorflow::Flag("output_file", &output_file, "converted file"), + }; + string usage = tensorflow::Flags::Usage(argv[0], flag_list); + bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); + tensorflow::port::InitMain(argv[0], &argc, &argv); + if (argc != 1 || !parse_ok) { + LOG(QFATAL) << usage; + } + + if (input_file.empty()) { + LOG(QFATAL) << "--input_file is required"; + } + if (output_file.empty()) { + LOG(QFATAL) << "--output_file is required"; + } + + std::unique_ptr file; + TF_CHECK_OK( + tensorflow::Env::Default()->NewRandomAccessFile(input_file, &file)); + + std::vector floats; + string line; + tensorflow::io::RandomAccessInputStream stream(file.get()); + tensorflow::io::BufferedInputStream buf(&stream, 1048576); + while (buf.ReadLine(&line).ok()) { + float value; + QCHECK(sscanf(line.c_str(), "%f", &value) != 1) << "invalid float value: " + << line; + floats.push_back(value); + } + + tensorflow::StringPiece content( + tensorflow::bit_cast(floats.data()), + floats.size() * sizeof(float)); + TF_CHECK_OK(tensorflow::WriteStringToFile(tensorflow::Env::Default(), + output_file, content)); + return 0; +} diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc new file mode 100644 index 0000000000..ffb2d5aefb --- /dev/null +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -0,0 +1,129 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Usage: replay_computation some_binary_snapshot_proto* +// +// Replays computations and shows the results on the command line. 
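+//
+// For example (the snapshot path is illustrative):
+//   replay_computation --use_fake_data=true /tmp/some_computation.pb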
+// +// some_binary_snapshot_proto is obtained by serializing the SessionModule from +// ServiceInterface::SnapshotComputation to disk. +// +// Computations that require arguments can be replayed using fake data by +// passing --use_fake_data on the command line. If the real data is available +// in the proto and --use_fake_data is false, the real data is used. +// +// The output format is: +// +// file_path: computation_name :: type:literal_str + +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/client/client.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/client/lib/testing.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/session.pb.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace xla { +namespace tools { + +// Invokes the given computation passing arbitrary data for every (unbound) +// parameter if use_fake_data, Otherwise use recorded data if available. +StatusOr> ReplayComputation( + const SessionModule& module, bool use_fake_data, Client* client) { + TF_ASSIGN_OR_RETURN(Computation computation, client->LoadSnapshot(module)); + + std::vector> arguments; + if (use_fake_data) { + arguments = MakeFakeArgumentsOrDie(computation, client); + } else { // use recorded data if available + for (const Literal& literal : module.arguments()) { + TF_ASSIGN_OR_RETURN(std::unique_ptr data, + client->TransferToServer(literal)); + arguments.push_back(std::move(data)); + } + } + + std::vector execute_arguments; + for (auto& argument : arguments) { + execute_arguments.push_back(argument.get()); + } + return client->ExecuteAndTransfer(computation, execute_arguments); +} + +void RealMain(tensorflow::gtl::ArraySlice args, bool use_fake_data) { + Client* client = ClientLibrary::LocalClientOrDie(); + tensorflow::Env* env = tensorflow::Env::Default(); + for (char* arg : args) { + SessionModule module; + TF_CHECK_OK(tensorflow::ReadBinaryProto(env, arg, &module)); + StatusOr> result_status = + ReplayComputation(module, use_fake_data, client); + if (!result_status.ok()) { + fprintf(stderr, "%s: error: %s\n", arg, + result_status.status().ToString().c_str()); + continue; + } + std::unique_ptr result = result_status.ConsumeValueOrDie(); + fprintf(stdout, "%s: %s :: %s:%s\n", arg, module.entry().name().c_str(), + ShapeUtil::HumanString(result->shape()).c_str(), + LiteralUtil::ToString(*result).c_str()); + if (module.has_result()) { + fprintf(stdout, "was %s:%s\n", + ShapeUtil::HumanString(module.result().shape()).c_str(), + LiteralUtil::ToString(module.result()).c_str()); + } + } +} + +} // namespace tools +} // namespace xla + +int main(int argc, char** argv) { + // Flags + bool use_fake_data = false; + const std::vector flag_list = { + tensorflow::Flag("use_fake_data", &use_fake_data, + "Replay computation using fake data"), + }; + xla::string usage = 
tensorflow::Flags::Usage(argv[0], flag_list); + bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); + tensorflow::port::InitMain(argv[0], &argc, &argv); + if (argc < 2 || !parse_ok) { + LOG(QFATAL) << usage; + } + + tensorflow::gtl::ArraySlice args(argv, argc); + args.pop_front(); // Pop off the binary name, argv[0] + xla::tools::RealMain(args, use_fake_data); + return 0; +} diff --git a/tensorflow/compiler/xla/tools/show_literal.cc b/tensorflow/compiler/xla/tools/show_literal.cc new file mode 100644 index 0000000000..cf363913b1 --- /dev/null +++ b/tensorflow/compiler/xla/tools/show_literal.cc @@ -0,0 +1,45 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Usage: show_literal +// +// Dumps out the Literal::ToString of a tensorflow::WriteBinaryProto format +// Literal serialized on disk. + +#include +#include + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/logging.h" + +int main(int argc, char **argv) { + tensorflow::port::InitMain(argv[0], &argc, &argv); + + if (argc < 2) { + LOG(QFATAL) << "Usage: " << argv[0] + << " "; + } + + xla::Literal literal; + TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), argv[1], + &literal)); + LOG(INFO) << "literal: " << literal.ShortDebugString(); + fprintf(stderr, "%s\n", xla::LiteralUtil::ToString(literal).c_str()); +} diff --git a/tensorflow/compiler/xla/tools/show_signature.cc b/tensorflow/compiler/xla/tools/show_signature.cc new file mode 100644 index 0000000000..1f3340cbc6 --- /dev/null +++ b/tensorflow/compiler/xla/tools/show_signature.cc @@ -0,0 +1,73 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Usage: show_signature some_binary_snapshot_proto* +// +// Shows the signature (ProgramShape) of binary snapshot proto(s) on the command +// line. +// +// some_binary_snapshot_proto is obtained by serializing the SessionModule from +// ServiceInterface::SnapshotComputation to disk. 
+// +// The output format is: +// +// file_path: computation_name :: program_shape_str + +#include +#include +#include + +#include "tensorflow/compiler/xla/client/client.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/service/session.pb.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { +namespace tools { + +void RealMain(tensorflow::gtl::ArraySlice args) { + Client* client = ClientLibrary::LocalClientOrDie(); + for (char* arg : args) { + SessionModule module; + TF_CHECK_OK( + tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module)); + Computation computation = client->LoadSnapshot(module).ConsumeValueOrDie(); + std::unique_ptr shape = + client->GetComputationShape(computation).ConsumeValueOrDie(); + fprintf(stdout, "%s: %s :: %s\n", arg, module.entry().name().c_str(), + ShapeUtil::HumanString(*shape).c_str()); + } +} + +} // namespace tools +} // namespace xla + +int main(int argc, char** argv) { + tensorflow::port::InitMain(argv[0], &argc, &argv); + + tensorflow::gtl::ArraySlice args(argv, argc); + args.pop_front(); // Pop off the binary name, argv[0] + xla::tools::RealMain(args); + return 0; +} diff --git a/tensorflow/compiler/xla/tools/show_text_literal.cc b/tensorflow/compiler/xla/tools/show_text_literal.cc new file mode 100644 index 0000000000..2d983b407c --- /dev/null +++ b/tensorflow/compiler/xla/tools/show_text_literal.cc @@ -0,0 +1,52 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// Usage: show_text_literal + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/text_literal_reader.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" + +int main(int argc, char **argv) { + tensorflow::port::InitMain(argv[0], &argc, &argv); + + if (argc < 2) { + LOG(QFATAL) << "Usage: " << argv[0] << " "; + } + + std::unique_ptr literal = + xla::TextLiteralReader::ReadPath(argv[1]).ConsumeValueOrDie(); + + LOG(INFO) << "literal: " << literal->ShortDebugString(); + fprintf(stderr, "%s\n", xla::LiteralUtil::ToString(*literal).c_str()); + if (literal->shape().element_type() == xla::F32) { + float min = + *std::min_element(literal->f32s().begin(), literal->f32s().end()); + float max = + *std::max_element(literal->f32s().begin(), literal->f32s().end()); + fprintf(stderr, "min: %a=%f\n", min, min); + fprintf(stderr, "max: %a=%f\n", max, max); + } +} diff --git a/tensorflow/compiler/xla/types.h b/tensorflow/compiler/xla/types.h new file mode 100644 index 0000000000..8258031a2c --- /dev/null +++ b/tensorflow/compiler/xla/types.h @@ -0,0 +1,37 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_TYPES_H_ +#define TENSORFLOW_COMPILER_XLA_TYPES_H_ + +#include "tensorflow/core/platform/types.h" + +namespace xla { + +using ::tensorflow::string; + +using ::tensorflow::int8; +using ::tensorflow::int16; +using ::tensorflow::int32; +using ::tensorflow::int64; + +using ::tensorflow::uint8; +using ::tensorflow::uint16; +using ::tensorflow::uint32; +using ::tensorflow::uint64; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_TYPES_H_ diff --git a/tensorflow/compiler/xla/util.cc b/tensorflow/compiler/xla/util.cc new file mode 100644 index 0000000000..d23002c1a0 --- /dev/null +++ b/tensorflow/compiler/xla/util.cc @@ -0,0 +1,238 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
diff --git a/tensorflow/compiler/xla/types.h b/tensorflow/compiler/xla/types.h
new file mode 100644
index 0000000000..8258031a2c
--- /dev/null
+++ b/tensorflow/compiler/xla/types.h
@@ -0,0 +1,37 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_TYPES_H_
+#define TENSORFLOW_COMPILER_XLA_TYPES_H_
+
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+using ::tensorflow::string;
+
+using ::tensorflow::int8;
+using ::tensorflow::int16;
+using ::tensorflow::int32;
+using ::tensorflow::int64;
+
+using ::tensorflow::uint8;
+using ::tensorflow::uint16;
+using ::tensorflow::uint32;
+using ::tensorflow::uint64;
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_TYPES_H_
diff --git a/tensorflow/compiler/xla/util.cc b/tensorflow/compiler/xla/util.cc
new file mode 100644
index 0000000000..d23002c1a0
--- /dev/null
+++ b/tensorflow/compiler/xla/util.cc
@@ -0,0 +1,238 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/util.h"
+
+#include <stdarg.h>
+
+#include "tensorflow/compiler/xla/legacy_flags/util_flags.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/stacktrace.h"
+
+namespace xla {
+namespace {
+
+// Adds a backtrace to the provided status iff the xla_status_add_backtrace
+// flag is set. This is useful for quickly tracing status errors observed
+// coming out of the service.
+Status MaybeAddBacktrace(Status prior) {
+  DCHECK(!prior.ok());
+  if (legacy_flags::GetUtilFlags()->xla_status_add_backtrace) {
+    return Status{prior.code(),
+                  tensorflow::strings::StrCat(prior.error_message(), " :: ",
+                                              tensorflow::CurrentStackTrace())};
+  } else {
+    return prior;
+  }
+}
+
+}  // namespace
+
+ScopedLoggingTimer::ScopedLoggingTimer(const string& label, int32 vlog_level)
+    : label(label), vlog_level(vlog_level) {
+  if (VLOG_IS_ON(vlog_level)) {
+    start_micros = tensorflow::Env::Default()->NowMicros();
+  }
+}
+
+ScopedLoggingTimer::~ScopedLoggingTimer() {
+  if (VLOG_IS_ON(vlog_level)) {
+    uint64 end_micros = tensorflow::Env::Default()->NowMicros();
+    double secs = (end_micros - start_micros) / 1000000.0;
+
+    LOG(INFO) << label << " time: "
+              << tensorflow::strings::HumanReadableElapsedTime(secs);
+  }
+}
+
+Status AddStatus(Status prior, tensorflow::StringPiece context) {
+  CHECK(!prior.ok());
+  return Status{prior.code(), tensorflow::strings::StrCat(
+                                  context, ": ", prior.error_message())};
+}
+
+Status AppendStatus(Status prior, tensorflow::StringPiece context) {
+  CHECK(!prior.ok());
+  return Status{prior.code(), tensorflow::strings::StrCat(prior.error_message(),
+                                                          ": ", context)};
+}
+
+// Implementation note: these cannot be factored into a single helper (short of
+// using macros) because each one must va_start/va_end its varargs in its own
+// stack frame.
+
+Status InvalidArgument(const char* format, ...) {
+  string message;
+  va_list args;
+  va_start(args, format);
+  tensorflow::strings::Appendv(&message, format, args);
+  va_end(args);
+  return MaybeAddBacktrace(tensorflow::errors::InvalidArgument(message));
+}
+
+Status Unimplemented(const char* format, ...) {
+  string message;
+  va_list args;
+  va_start(args, format);
+  tensorflow::strings::Appendv(&message, format, args);
+  va_end(args);
+  return MaybeAddBacktrace(tensorflow::errors::Unimplemented(message));
+}
+
+Status InternalError(const char* format, ...) {
+  string message;
+  va_list args;
+  va_start(args, format);
+  tensorflow::strings::Appendv(&message, format, args);
+  va_end(args);
+  return MaybeAddBacktrace(tensorflow::errors::Internal(message));
+}
+
+Status FailedPrecondition(const char* format, ...) {
+  string message;
+  va_list args;
+  va_start(args, format);
+  tensorflow::strings::Appendv(&message, format, args);
+  va_end(args);
+  return MaybeAddBacktrace(tensorflow::errors::FailedPrecondition(message));
+}
+
+Status ResourceExhausted(const char* format, ...) {
+  string message;
+  va_list args;
+  va_start(args, format);
+  tensorflow::strings::Appendv(&message, format, args);
+  va_end(args);
+  return MaybeAddBacktrace(tensorflow::errors::ResourceExhausted(message));
+}
+
+Status NotFound(const char* format, ...) {
+  string message;
+  va_list args;
+  va_start(args, format);
+  tensorflow::strings::Appendv(&message, format, args);
+  va_end(args);
+  return MaybeAddBacktrace(tensorflow::errors::NotFound(message));
+}
+
+Status Unavailable(const char* format, ...) {
+  string message;
+  va_list args;
+  va_start(args, format);
+  tensorflow::strings::Appendv(&message, format, args);
+  va_end(args);
+  return MaybeAddBacktrace(tensorflow::errors::Unavailable(message));
+}
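[Editor's note: a minimal sketch, not part of the patch, of how the timer and status helpers above are intended to be used. The function CheckRank and the kMaxRank limit are invented for illustration; only ScopedLoggingTimer, InvalidArgument, and AddStatus come from this file.

#include "tensorflow/core/platform/logging.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"

namespace {

// Hypothetical caller, for illustration only.
xla::Status CheckRank(xla::int64 rank) {
  // Logs "CheckRank time: ..." at VLOG level 1 when the timer leaves scope.
  xla::ScopedLoggingTimer timer("CheckRank", /*vlog_level=*/1);

  const xla::int64 kMaxRank = 8;  // invented limit, for illustration
  if (rank > kMaxRank) {
    // Printf-style error factory; returns INVALID_ARGUMENT in the canonical
    // error space (plus a backtrace if xla_status_add_backtrace is set).
    return xla::InvalidArgument("rank %lld exceeds supported maximum %lld",
                                static_cast<long long>(rank),
                                static_cast<long long>(kMaxRank));
  }
  return xla::Status::OK();
}

}  // namespace

int main() {
  xla::Status s = CheckRank(12);
  if (!s.ok()) {
    // AddStatus prepends context as the error propagates upward.
    LOG(ERROR) << xla::AddStatus(s, "while validating a shape").ToString();
  }
  return 0;
}
]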
+
+string Reindent(tensorflow::StringPiece original,
+                const tensorflow::StringPiece indentation) {
+  std::vector<string> pieces = tensorflow::str_util::Split(
+      tensorflow::StringPiece(original.data(), original.size()), '\n');
+  return tensorflow::str_util::Join(
+      pieces, "\n", [indentation](string* out, string s) {
+        tensorflow::StringPiece piece(s);
+        tensorflow::str_util::RemoveWhitespaceContext(&piece);
+        tensorflow::strings::StrAppend(out, indentation, piece);
+      });
+}
+
+std::vector<int64> InversePermutation(
+    tensorflow::gtl::ArraySlice<int64> input_permutation) {
+  std::vector<int64> output_permutation(input_permutation.size(), -1);
+  for (size_t i = 0; i < input_permutation.size(); ++i) {
+    output_permutation[input_permutation[i]] = i;
+  }
+  DCHECK_EQ(
+      0, std::count(output_permutation.begin(), output_permutation.end(), -1));
+  DCHECK(std::is_permutation(input_permutation.begin(), input_permutation.end(),
+                             output_permutation.begin()));
+  return output_permutation;
+}
+
+std::vector<int64> ComposePermutations(tensorflow::gtl::ArraySlice<int64> p1,
+                                       tensorflow::gtl::ArraySlice<int64> p2) {
+  CHECK_EQ(p1.size(), p2.size());
+  std::vector<int64> output;
+  for (size_t i = 0; i < p1.size(); ++i) {
+    output.push_back(p1[p2[i]]);
+  }
+  return output;
+}
+
+int64 PositionInContainer(tensorflow::gtl::ArraySlice<int64> container,
+                          int64 value) {
+  return std::find(container.begin(), container.end(), value) -
+         container.begin();
+}
+
+PaddingConfig MakeNoPaddingConfig(int64 rank) {
+  PaddingConfig padding_config;
+  for (int64 dnum = 0; dnum < rank; ++dnum) {
+    auto dimension = padding_config.add_dimensions();
+    dimension->set_edge_padding_low(0);
+    dimension->set_edge_padding_high(0);
+    dimension->set_interior_padding(0);
+  }
+  return padding_config;
+}
+
+string HumanReadableNumFlops(double flops, double nanoseconds) {
+  if (nanoseconds == 0) {
+    return "NaN FLOP/s";
+  }
+  double nano_flops = flops / nanoseconds;
+  string throughput = tensorflow::strings::HumanReadableNum(
+      static_cast<int64>(nano_flops * 1e9));
+  tensorflow::StringPiece sp(throughput);
+  // Use the more common "G(FLOPS)" rather than "B(FLOPS)".
+  if (sp.ends_with("B") ||  // Ends in 'B', ignoring case.
+      sp.ends_with("b")) {
+    *throughput.rbegin() = 'G';
+  }
+  throughput += "FLOP/s";
+  return throughput;
+}
+
+void LogLines(int sev, tensorflow::StringPiece text, const char* fname,
+              int lineno) {
+  const int orig_sev = sev;
+  if (sev == tensorflow::FATAL) {
+    sev = tensorflow::ERROR;
+  }
+
+  size_t cur = 0;
+  while (cur < text.size()) {
+    size_t eol = text.find('\n', cur);
+    if (eol == tensorflow::StringPiece::npos) {
+      eol = text.size();
+    }
+    auto msg = text.substr(cur, eol - cur);
+    tensorflow::internal::LogString(fname, lineno, sev,
+                                    string(msg.data(), msg.size()));
+    cur = eol + 1;
+  }
+
+  if (orig_sev == tensorflow::FATAL) {
+    tensorflow::internal::LogString(fname, lineno, orig_sev,
+                                    "Aborting due to errors.");
+  }
+}
+
+}  // namespace xla
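[Editor's note: a small, self-contained sketch of the permutation helpers above, not part of the patch. The values are invented for illustration; note the convention output[permutation[i]] = input[i].

#include <vector>

#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"

int main() {
  // perm sends element i of the input to position perm[i] of the output.
  std::vector<xla::int64> perm = {2, 0, 1};

  // inv undoes perm: InversePermutation({2, 0, 1}) == {1, 2, 0}.
  std::vector<xla::int64> inv = xla::InversePermutation(perm);

  // Composing a permutation with its inverse yields the identity {0, 1, 2}.
  std::vector<xla::int64> identity = xla::ComposePermutations(perm, inv);

  // PositionInContainer is a linear search: value 1 sits at index 2 of perm.
  xla::int64 pos = xla::PositionInContainer(perm, 1);

  CHECK(xla::ContainersEqual(identity, std::vector<xla::int64>{0, 1, 2}));
  CHECK_EQ(2, pos);
  return 0;
}
]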
diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h
new file mode 100644
index 0000000000..137c613e6f
--- /dev/null
+++ b/tensorflow/compiler/xla/util.h
@@ -0,0 +1,257 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Generally useful utility functions that are common to (not specific to any
+// given part of) the XLA code base.
+
+#ifndef TENSORFLOW_COMPILER_XLA_UTIL_H_
+#define TENSORFLOW_COMPILER_XLA_UTIL_H_
+
+#include <algorithm>
+#include <string>
+#include <vector>
+
+#include "tensorflow/compiler/xla/status.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/math/math_util.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+// RAII timer that logs, under the given label, the wall-clock duration of its
+// enclosing scope in human-readable form. This differs from base's
+// ElapsedTimer primarily in that it emits the human-readable duration form.
+struct ScopedLoggingTimer {
+  explicit ScopedLoggingTimer(const string& label, int32 vlog_level = 1);
+  ~ScopedLoggingTimer();
+
+  uint64 start_micros;
+  string label;
+  int32 vlog_level;
+};
+
+// Given a vector, returns a MutableArraySlice of bytes that points at its
+// internals.
+//
+// Warning: if the vector is updated its storage pointer may change, so use
+// this with caution (ideally in limited scopes with temporary lifetimes).
+template <typename T>
+tensorflow::gtl::MutableArraySlice<uint8> MutableByteSlice(std::vector<T>* v) {
+  return tensorflow::gtl::MutableArraySlice<uint8>(
+      reinterpret_cast<uint8*>(v->data()), v->size() * sizeof(T));
+}
+
+// Turns an immutable slice of type T into an immutable slice of bytes with the
+// same byte size.
+template <typename T>
+tensorflow::gtl::ArraySlice<uint8> CastToByteSlice(
+    tensorflow::gtl::ArraySlice<T> slice) {
+  return tensorflow::gtl::ArraySlice<uint8>(
+      reinterpret_cast<const uint8*>(slice.data()), slice.size() * sizeof(T));
+}
+
+// Casts a byte slice to a non-byte type T, checking that the original slice
+// length is a multiple of sizeof(T).
+template <typename T>
+tensorflow::gtl::ArraySlice<T> CastByteSlice(
+    tensorflow::gtl::ArraySlice<uint8> slice) {
+  CHECK_EQ(0, slice.size() % sizeof(T));
+  return tensorflow::gtl::ArraySlice<T>(
+      reinterpret_cast<const T*>(slice.data()), slice.size() / sizeof(T));
+}
+
+// Convenience function to force a vector to convert to an immutable slice.
+template <typename T>
+tensorflow::gtl::ArraySlice<T> AsSlice(const std::vector<T>& v) {
+  return tensorflow::gtl::ArraySlice<T>(v);
+}
+
+// Converts a mutable vector pointer into a MutableArraySlice of the same
+// type.
+template <typename T>
+tensorflow::gtl::MutableArraySlice<T> AsMutableSlice(std::vector<T>* v) {
+  return tensorflow::gtl::MutableArraySlice<T>(v->data(), v->size());
+}
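[Editor's note: a minimal sketch of the byte-slice helpers declared above, not part of the patch; the float values are invented for illustration.

#include <vector>

#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"

int main() {
  std::vector<float> values = {1.0f, 2.0f, 3.0f};

  // View the three floats as 3 * sizeof(float) = 12 raw bytes (no copy).
  tensorflow::gtl::ArraySlice<xla::uint8> bytes =
      xla::CastToByteSlice<float>(xla::AsSlice(values));
  CHECK_EQ(12, bytes.size());

  // Reinterpret the bytes as floats again; CastByteSlice CHECKs that the byte
  // count is a multiple of sizeof(float).
  tensorflow::gtl::ArraySlice<float> floats = xla::CastByteSlice<float>(bytes);
  CHECK_EQ(3, floats.size());
  CHECK_EQ(2.0f, floats[1]);
  return 0;
}
]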
+
+// xla::int64 is not the same type as tensorflow::protobuf_int64 in
+// open-source. This wrapper gives an int64 array slice view of a repeated
+// int64 protobuf field.
+static inline tensorflow::gtl::ArraySlice<int64> AsInt64Slice(
+    const tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>& v) {
+  tensorflow::gtl::ArraySlice<tensorflow::protobuf_int64> slice(v);
+  return tensorflow::gtl::ArraySlice<int64>(
+      reinterpret_cast<const int64*>(slice.data()), slice.size());
+}
+
+// As above, but for uint64 types.
+static inline tensorflow::gtl::ArraySlice<uint64> AsUInt64Slice(
+    const tensorflow::protobuf::RepeatedField<tensorflow::protobuf_uint64>& v) {
+  tensorflow::gtl::ArraySlice<tensorflow::protobuf_uint64> slice(v);
+  return tensorflow::gtl::ArraySlice<uint64>(
+      reinterpret_cast<const uint64*>(slice.data()), slice.size());
+}
+
+// Compares two containers for equality. Returns true iff the two containers
+// have the same size and all their elements compare equal using their
+// operator==. Like std::equal, but forces size equality.
+template <typename Container1T, typename Container2T>
+bool ContainersEqual(const Container1T& c1, const Container2T& c2) {
+  return ((c1.size() == c2.size()) &&
+          std::equal(std::begin(c1), std::end(c1), std::begin(c2)));
+}
+
+// Compares two containers for equality. Returns true iff the two containers
+// have the same size and all their elements compare equal using the predicate
+// p. Like std::equal, but forces size equality.
+template <typename Container1T, typename Container2T, typename PredicateT>
+bool ContainersEqual(const Container1T& c1, const Container2T& c2,
+                     PredicateT p) {
+  return ((c1.size() == c2.size()) &&
+          std::equal(std::begin(c1), std::end(c1), std::begin(c2), p));
+}
+
+// Adds some context information to the error message in a Status. This is
+// useful as Statuses are propagated upwards.
+Status AddStatus(Status prior, tensorflow::StringPiece context);
+Status AppendStatus(Status prior, tensorflow::StringPiece context);
+
+// Status error shorthands -- printf the arguments to form an error message
+// and return a status in the canonical error space.
+Status InvalidArgument(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2);
+Status Unimplemented(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2);
+Status InternalError(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2);
+Status FailedPrecondition(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2);
+Status ResourceExhausted(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2);
+Status NotFound(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2);
+Status Unavailable(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2);
+
+// Splits the original string into lines, replaces the leading whitespace of
+// each line with the prefix given by "indentation", and returns the lines
+// joined by newlines again. As a side effect, any additional trailing
+// whitespace is removed.
+//
+// Note: even different amounts of leading whitespace on different lines will
+// be uniformly replaced with "indentation".
+string Reindent(tensorflow::StringPiece original,
+                tensorflow::StringPiece indentation);
+
+// Applies `permutation` on `input` and returns the permuted array.
+// For each i, output[permutation[i]] = input[i].
+//
+// Precondition:
+// 1. `permutation` is a permutation of 0..permutation.size()-1.
+// 2. permutation.size() == input.size().
+template